
fix name_scope stack not being restored when name is empty.

tags/v0.8.0
Oceania2018 6 years ago
parent
commit de7a6941a1
7 changed files with 101 additions and 35 deletions
  1. src/TensorFlowNET.Core/Graphs/Graph.cs (+3, -1)
  2. src/TensorFlowNET.Core/Operations/array_ops.py.cs (+7, -15)
  3. src/TensorFlowNET.Core/TensorFlowNET.Core.csproj (+6, -0)
  4. src/TensorFlowNET.Core/Variables/VariableScope.cs (+12, -5)
  5. src/TensorFlowNET.Core/Variables/_VariableStore.cs (+46, -6)
  6. src/TensorFlowNET.Core/Variables/tf.variable.cs (+11, -3)
  7. src/TensorFlowNET.Core/Variables/variable_scope.py.cs (+16, -5)

src/TensorFlowNET.Core/Graphs/Graph.cs (+3, -1)

@@ -197,7 +197,9 @@ namespace Tensorflow
  {
      string new_stack = "";

-     if (name.EndsWith("/"))
+     if (string.IsNullOrEmpty(name))
+         new_stack = "";
+     else if (name.EndsWith("/"))
          new_stack = ops._name_from_scope_name(name);
      else
          new_stack = unique_name(name);
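
For context (not part of the diff): the sketch below illustrates what the new branch enables. It only uses identifiers visible in this commit (Python.with, ops.name_scope, unique_name); the prefixes in the comments are illustrative.

    // Hedged sketch: entering a name_scope with an empty name now sets the
    // graph's name stack to "" instead of pushing unique_name(""), so the outer
    // stack can be restored when the scope exits.
    Python.with(new ops.name_scope("layer1"), outer =>
    {
        // ops created here are expected to carry the "layer1/" prefix
        Python.with(new ops.name_scope(""), inner =>
        {
            // new_stack == "" -> ops here land at the root scope
        });
        // on exit, the "layer1/" stack is restored again (the bug this commit fixes)
    });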


src/TensorFlowNET.Core/Operations/array_ops.py.cs (+7, -15)

@@ -9,32 +9,24 @@ namespace Tensorflow
  {
      public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")
      {
-         Tensor output = null;
-
          dtype = dtype.as_base_dtype();
-         Python.with(new ops.name_scope(name, "zeros", shape), self =>
+         return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "zeros", shape), scope =>
          {
-             name = self as ops.name_scope;
+             name = scope;
              switch (dtype)
              {
                  case TF_DataType.TF_BOOL:
-                     output = _constant_if_small(false, shape, dtype, name);
-                     break;
+                     return _constant_if_small(false, shape, dtype, name);
                  case TF_DataType.TF_DOUBLE:
-                     output = _constant_if_small(0.0D, shape, dtype, name);
-                     break;
+                     return _constant_if_small(0.0D, shape, dtype, name);
                  case TF_DataType.TF_FLOAT:
-                     output = _constant_if_small(0.0F, shape, dtype, name);
-                     break;
+                     return _constant_if_small(0.0F, shape, dtype, name);
                  case TF_DataType.TF_INT32:
-                     output = _constant_if_small(0, shape, dtype, name);
-                     break;
+                     return _constant_if_small(0, shape, dtype, name);
                  default:
-                     break;
+                     throw new TypeError("can't find type for zeros");
              }
          });
-
-         return output;
      }

      private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name)
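
A brief usage sketch of the refactored method (not part of the diff). It assumes the enclosing class is array_ops, matching the file name, and that Shape can be constructed from integer dimensions; the visible change is that each case now returns directly out of the generic Python.with<ops.name_scope, Tensor> block, and an unsupported dtype throws TypeError instead of yielding a null Tensor.

    // Hedged example (class name and Shape constructor assumed):
    Tensor z = array_ops.zeros(new Shape(2, 3), TF_DataType.TF_INT32, name: "zeros_example");

    // Unsupported element types no longer fall through silently:
    // array_ops.zeros(new Shape(2, 3), TF_DataType.TF_STRING);   // throws TypeError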


src/TensorFlowNET.Core/TensorFlowNET.Core.csproj (+6, -0)

@@ -31,6 +31,12 @@ TensorFlow 1.13 RC.</PackageReleaseNotes>
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
  </PropertyGroup>

+ <ItemGroup>
+   <Compile Remove="runtimes\**" />
+   <EmbeddedResource Remove="runtimes\**" />
+   <None Remove="runtimes\**" />
+ </ItemGroup>
+
  <ItemGroup>
    <None Remove="Protobuf\README.md" />
  </ItemGroup>


src/TensorFlowNET.Core/Variables/VariableScope.cs (+12, -5)

@@ -13,28 +13,35 @@ namespace Tensorflow
      private TF_DataType _dtype;
      public string name { get; set; }

-     public VariableScope()
+     public VariableScope(TF_DataType dtype = TF_DataType.TF_FLOAT)
      {
          _reuse = _ReuseMode.AUTO_REUSE;
+         _dtype = dtype;
      }

      public RefVariable get_variable(_VariableStore var_store,
          string name,
          TensorShape shape = null,
          TF_DataType dtype = TF_DataType.DtInvalid,
          IInitializer initializer = null,
          bool? trainable = null,
          VariableSynchronization synchronization = VariableSynchronization.AUTO,
          VariableAggregation aggregation = VariableAggregation.NONE)
      {
          string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
-         return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope =>
+         return Python.with<ops.name_scope, RefVariable>(new ops.name_scope(""), scope =>
          {
+             if (dtype == TF_DataType.DtInvalid)
+                 dtype = _dtype;

-             return var_store.get_variable(full_name);
+             return var_store.get_variable(full_name,
+                 shape: shape,
+                 dtype: dtype,
+                 initializer: initializer,
+                 trainable: trainable,
+                 synchronization: synchronization,
+                 aggregation: aggregation);
          });

      }
  }
}
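
A hedged sketch of what the widened signature means for callers (not part of the diff). The helper name and the TensorShape constructor are illustrative, and a real IInitializer must be supplied because the store path below dereferences it.

    static void DtypeFallbackSketch(IInitializer init)
    {
        var scope = variable_scope.get_variable_scope();
        var store = variable_scope._get_default_variable_store();

        // dtype omitted -> TF_DataType.DtInvalid -> replaced by the scope's _dtype (TF_FLOAT)
        var w = scope.get_variable(store, "w", shape: new TensorShape(3, 3), initializer: init);

        // an explicit dtype is passed through to the variable store unchanged
        var idx = scope.get_variable(store, "idx", shape: new TensorShape(3),
            dtype: TF_DataType.TF_INT32, initializer: init);
    }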

src/TensorFlowNET.Core/Variables/_VariableStore.cs (+46, -6)

@@ -24,7 +24,7 @@ namespace Tensorflow
          TensorShape shape = null,
          TF_DataType dtype = TF_DataType.TF_FLOAT,
          IInitializer initializer = null,
-         bool trainable = false,
+         bool? trainable = null,
          bool validate_shape = true,
          VariableSynchronization synchronization = VariableSynchronization.AUTO,
          VariableAggregation aggregation = VariableAggregation.NONE)
@@ -44,14 +44,23 @@ namespace Tensorflow

      private RefVariable _true_getter(string name,
          TensorShape shape = null,
-         TF_DataType dtype = TF_DataType.DtInvalid,
+         TF_DataType dtype = TF_DataType.TF_FLOAT,
          IInitializer initializer = null,
-         bool trainable = false,
+         bool? trainable = null,
          bool validate_shape = true,
          VariableSynchronization synchronization = VariableSynchronization.AUTO,
          VariableAggregation aggregation = VariableAggregation.NONE)
      {
-         return _get_single_variable(name: name);
+         bool is_scalar = shape.NDim == 0;
+
+         return _get_single_variable(name: name,
+             shape: shape,
+             dtype: dtype,
+             initializer: initializer,
+             trainable: trainable,
+             validate_shape: validate_shape,
+             synchronization: synchronization,
+             aggregation: aggregation);
      }

      private RefVariable _get_single_variable(string name,
@@ -59,11 +68,14 @@ namespace Tensorflow
          TF_DataType dtype = TF_DataType.DtInvalid,
          IInitializer initializer = null,
          bool reuse = false,
-         bool trainable = false,
+         bool? trainable = null,
          bool validate_shape = false,
          bool? use_resource = null,
          VariableSynchronization synchronization = VariableSynchronization.AUTO,
          VariableAggregation aggregation = VariableAggregation.NONE)
      {
+         bool initializing_from_value = false;
+
          if (_vars.ContainsKey(name))
          {
              if (!reuse)
@@ -74,7 +86,35 @@ namespace Tensorflow
                  throw new NotImplementedException("_get_single_variable");
          }

-         throw new NotImplementedException("_get_single_variable");
+         Tensor init_val = null;
+         ops.init_scope();
+         {
+             if (initializing_from_value)
+             {
+
+             }
+             else
+             {
+                 init_val = initializer.call(shape, dtype);
+                 var variable_dtype = dtype.as_base_dtype();
+             }
+         }
+
+         // Create the variable.
+         if (use_resource == null)
+             use_resource = false;
+
+         var v = variable_scope.default_variable_creator(init_val,
+             name: name,
+             trainable: trainable,
+             dtype: TF_DataType.DtInvalid,
+             validate_shape: validate_shape,
+             synchronization: synchronization,
+             aggregation: aggregation);
+
+         _vars[name] = v;
+
+         return v;
      }
  }
}
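
For orientation (not part of the diff): the helper below mirrors the new creation path for a name that is not yet cached in _vars. The method name is hypothetical; initializer.call and default_variable_creator are the calls the diff introduces, and the initializer is left as a parameter because this path uses it unconditionally.

    static RefVariable CreateLikeVariableStore(string name, TensorShape shape,
        TF_DataType dtype, IInitializer initializer)
    {
        // build the initial value tensor from the initializer, as _get_single_variable does
        Tensor init_val = initializer.call(shape, dtype);

        // hand it to default_variable_creator; the store then caches it under _vars[name]
        return variable_scope.default_variable_creator(init_val,
            name: name,
            trainable: true,
            validate_shape: true);
    }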

src/TensorFlowNET.Core/Variables/tf.variable.cs (+11, -3)

@@ -12,14 +12,22 @@ namespace Tensorflow
          return variables.variables_initializer(g.ToArray());
      }

      public static RefVariable get_variable(string name,
          TensorShape shape = null,
          TF_DataType dtype = TF_DataType.DtInvalid,
          IInitializer initializer = null,
          bool? trainable = null,
          VariableSynchronization synchronization = VariableSynchronization.AUTO,
          VariableAggregation aggregation = VariableAggregation.NONE)
      {
+         var scope = variable_scope.get_variable_scope();
          var store = variable_scope._get_default_variable_store();
-         return variable_scope.get_variable_scope().get_variable(store, name, shape: shape);
+         return scope.get_variable(store,
+             name,
+             shape: shape,
+             dtype: dtype,
+             initializer: initializer,
+             trainable: trainable);
      }
  }
}
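
A short end-to-end sketch of the public entry point as wired above (not part of the diff). tf.zeros_initializer is assumed here purely for illustration; any available IInitializer works, and one is needed because the new store path calls initializer.call.

    var weights = tf.get_variable("weights",
        shape: new TensorShape(10, 10),
        initializer: tf.zeros_initializer,   // assumed IInitializer; substitute as needed
        trainable: true);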

src/TensorFlowNET.Core/Variables/variable_scope.py.cs (+16, -5)

@@ -10,9 +10,16 @@ namespace Tensorflow
      public static string _VARSCOPESTORE_KEY = "__varscope";
      public static bool _DEFAULT_USE_RESOURCE = false;

-     public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO)
+     public static RefVariable default_variable_creator(object initial_value,
+         string name = "",
+         bool? trainable = null,
+         TF_DataType dtype = TF_DataType.DtInvalid,
+         bool validate_shape = false,
+         bool ? use_resource = null,
+         VariableSynchronization synchronization = VariableSynchronization.AUTO,
+         VariableAggregation aggregation = VariableAggregation.NONE)
      {
-         var trainable = _get_trainable_value(synchronization);
+         trainable = _get_trainable_value(synchronization, trainable);
          if (!use_resource.HasValue)
          {
              use_resource = get_variable_scope().use_resource;
@@ -77,18 +84,22 @@ namespace Tensorflow
          return ret;
      }

-     public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true)
+     public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true)
      {
          if (synchronization == VariableSynchronization.ON_READ)
          {
-             if (trainable)
+             if (trainable.Value)
                  throw new ValueError("Synchronization value can be set to " +
                      "VariableSynchronization.ON_READ only for non-trainable variables. " +
                      "You have specified trainable=True and " +
                      "synchronization=VariableSynchronization.ON_READ.");
          }
+         else if (!trainable.HasValue)
+         {
+             trainable = true;
+         }
-         return trainable;
+         return trainable.Value;
      }
  }
}
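
To make the new nullable-trainable semantics concrete (not part of the diff), the calls below follow the logic above; note that passing null together with ON_READ would still dereference trainable.Value before the default is applied, so the defaulting only covers the non-ON_READ branch.

    bool a = variable_scope._get_trainable_value(VariableSynchronization.AUTO, null);      // true (defaulted)
    bool b = variable_scope._get_trainable_value(VariableSynchronization.AUTO, false);     // false
    bool c = variable_scope._get_trainable_value(VariableSynchronization.ON_READ, false);  // false
    // variable_scope._get_trainable_value(VariableSynchronization.ON_READ, true);         // throws ValueError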
