Browse Source

fix name_scope stack can't restore when name is empty.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
de7a6941a1
7 changed files with 101 additions and 35 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +7
    -15
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  3. +6
    -0
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  4. +12
    -5
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  5. +46
    -6
      src/TensorFlowNET.Core/Variables/_VariableStore.cs
  6. +11
    -3
      src/TensorFlowNET.Core/Variables/tf.variable.cs
  7. +16
    -5
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs

+ 3
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -197,7 +197,9 @@ namespace Tensorflow
{ {
string new_stack = ""; string new_stack = "";


if (name.EndsWith("/"))
if (string.IsNullOrEmpty(name))
new_stack = "";
else if (name.EndsWith("/"))
new_stack = ops._name_from_scope_name(name); new_stack = ops._name_from_scope_name(name);
else else
new_stack = unique_name(name); new_stack = unique_name(name);


+ 7
- 15
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -9,32 +9,24 @@ namespace Tensorflow
{ {
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")
{ {
Tensor output = null;

dtype = dtype.as_base_dtype(); dtype = dtype.as_base_dtype();
Python.with(new ops.name_scope(name, "zeros", shape), self =>
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "zeros", shape), scope =>
{ {
name = self as ops.name_scope;
name = scope;
switch (dtype) switch (dtype)
{ {
case TF_DataType.TF_BOOL: case TF_DataType.TF_BOOL:
output = _constant_if_small(false, shape, dtype, name);
break;
return _constant_if_small(false, shape, dtype, name);
case TF_DataType.TF_DOUBLE: case TF_DataType.TF_DOUBLE:
output = _constant_if_small(0.0D, shape, dtype, name);
break;
return _constant_if_small(0.0D, shape, dtype, name);
case TF_DataType.TF_FLOAT: case TF_DataType.TF_FLOAT:
output = _constant_if_small(0.0F, shape, dtype, name);
break;
return _constant_if_small(0.0F, shape, dtype, name);
case TF_DataType.TF_INT32: case TF_DataType.TF_INT32:
output = _constant_if_small(0, shape, dtype, name);
break;
return _constant_if_small(0, shape, dtype, name);
default: default:
break;
throw new TypeError("can't find type for zeros");
} }
}); });

return output;
} }


private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name)


+ 6
- 0
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -31,6 +31,12 @@ TensorFlow 1.13 RC.</PackageReleaseNotes>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup> </PropertyGroup>


<ItemGroup>
<Compile Remove="runtimes\**" />
<EmbeddedResource Remove="runtimes\**" />
<None Remove="runtimes\**" />
</ItemGroup>

<ItemGroup> <ItemGroup>
<None Remove="Protobuf\README.md" /> <None Remove="Protobuf\README.md" />
</ItemGroup> </ItemGroup>


+ 12
- 5
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -13,28 +13,35 @@ namespace Tensorflow
private TF_DataType _dtype; private TF_DataType _dtype;
public string name { get; set; } public string name { get; set; }


public VariableScope()
public VariableScope(TF_DataType dtype = TF_DataType.TF_FLOAT)
{ {
_reuse = _ReuseMode.AUTO_REUSE; _reuse = _ReuseMode.AUTO_REUSE;
_dtype = dtype;
} }


public RefVariable get_variable(_VariableStore var_store, public RefVariable get_variable(_VariableStore var_store,
string name, string name,
TensorShape shape = null, TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
bool? trainable = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation= VariableAggregation.NONE) VariableAggregation aggregation= VariableAggregation.NONE)
{ {
string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope =>
return Python.with<ops.name_scope, RefVariable>(new ops.name_scope(""), scope =>
{ {
if (dtype == TF_DataType.DtInvalid) if (dtype == TF_DataType.DtInvalid)
dtype = _dtype; dtype = _dtype;


return var_store.get_variable(full_name);

return var_store.get_variable(full_name,
shape: shape,
dtype: dtype,
initializer: initializer,
trainable: trainable,
synchronization: synchronization,
aggregation: aggregation);
}); });

} }
} }
} }

+ 46
- 6
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
TensorShape shape = null, TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null, IInitializer initializer = null,
bool trainable = false,
bool? trainable = null,
bool validate_shape = true, bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE) VariableAggregation aggregation = VariableAggregation.NONE)
@@ -44,14 +44,23 @@ namespace Tensorflow


private RefVariable _true_getter(string name, private RefVariable _true_getter(string name,
TensorShape shape = null, TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null, IInitializer initializer = null,
bool trainable = false,
bool? trainable = null,
bool validate_shape = true, bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE) VariableAggregation aggregation = VariableAggregation.NONE)
{ {
return _get_single_variable(name: name);
bool is_scalar = shape.NDim == 0;

return _get_single_variable(name: name,
shape: shape,
dtype: dtype,
initializer: initializer,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
} }


private RefVariable _get_single_variable(string name, private RefVariable _get_single_variable(string name,
@@ -59,11 +68,14 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null, IInitializer initializer = null,
bool reuse = false, bool reuse = false,
bool trainable = false,
bool? trainable = null,
bool validate_shape = false, bool validate_shape = false,
bool? use_resource = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE) VariableAggregation aggregation = VariableAggregation.NONE)
{ {
bool initializing_from_value = false;

if (_vars.ContainsKey(name)) if (_vars.ContainsKey(name))
{ {
if (!reuse) if (!reuse)
@@ -74,7 +86,35 @@ namespace Tensorflow
throw new NotImplementedException("_get_single_variable"); throw new NotImplementedException("_get_single_variable");
} }


throw new NotImplementedException("_get_single_variable");
Tensor init_val = null;
ops.init_scope();
{
if (initializing_from_value)
{

}
else
{
init_val = initializer.call(shape, dtype);
var variable_dtype = dtype.as_base_dtype();
}
}

// Create the variable.
if (use_resource == null)
use_resource = false;

var v = variable_scope.default_variable_creator(init_val,
name: name,
trainable: trainable,
dtype: TF_DataType.DtInvalid,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);

_vars[name] = v;

return v;
} }
} }
} }

+ 11
- 3
src/TensorFlowNET.Core/Variables/tf.variable.cs View File

@@ -12,14 +12,22 @@ namespace Tensorflow
return variables.variables_initializer(g.ToArray()); return variables.variables_initializer(g.ToArray());
} }


public static RefVariable get_variable(string name,
TensorShape shape = null,
public static RefVariable get_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null, IInitializer initializer = null,
bool? trainable = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE) VariableAggregation aggregation = VariableAggregation.NONE)
{ {
var scope = variable_scope.get_variable_scope();
var store = variable_scope._get_default_variable_store(); var store = variable_scope._get_default_variable_store();
return variable_scope.get_variable_scope().get_variable(store, name, shape: shape);
return scope.get_variable(store,
name,
shape: shape,
dtype: dtype,
initializer: initializer,
trainable: trainable);
} }
} }
} }

+ 16
- 5
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -10,9 +10,16 @@ namespace Tensorflow
public static string _VARSCOPESTORE_KEY = "__varscope"; public static string _VARSCOPESTORE_KEY = "__varscope";
public static bool _DEFAULT_USE_RESOURCE = false; public static bool _DEFAULT_USE_RESOURCE = false;


public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO)
public static RefVariable default_variable_creator(object initial_value,
string name = "",
bool? trainable = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool validate_shape = false,
bool ? use_resource = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
{ {
var trainable = _get_trainable_value(synchronization);
trainable = _get_trainable_value(synchronization, trainable);
if (!use_resource.HasValue) if (!use_resource.HasValue)
{ {
use_resource = get_variable_scope().use_resource; use_resource = get_variable_scope().use_resource;
@@ -77,18 +84,22 @@ namespace Tensorflow
return ret; return ret;
} }


public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true)
public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true)
{ {
if (synchronization == VariableSynchronization.ON_READ) if (synchronization == VariableSynchronization.ON_READ)
{ {
if (trainable)
if (trainable.Value)
throw new ValueError("Synchronization value can be set to " + throw new ValueError("Synchronization value can be set to " +
"VariableSynchronization.ON_READ only for non-trainable variables. " + "VariableSynchronization.ON_READ only for non-trainable variables. " +
"You have specified trainable=True and " + "You have specified trainable=True and " +
"synchronization=VariableSynchronization.ON_READ."); "synchronization=VariableSynchronization.ON_READ.");
} }
else if (!trainable.HasValue)
{
trainable = true;
}
return trainable;
return trainable.Value;
} }
} }
} }

Loading…
Cancel
Save