| @@ -197,7 +197,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| string new_stack = ""; | string new_stack = ""; | ||||
| if (name.EndsWith("/")) | |||||
| if (string.IsNullOrEmpty(name)) | |||||
| new_stack = ""; | |||||
| else if (name.EndsWith("/")) | |||||
| new_stack = ops._name_from_scope_name(name); | new_stack = ops._name_from_scope_name(name); | ||||
| else | else | ||||
| new_stack = unique_name(name); | new_stack = unique_name(name); | ||||
| @@ -9,32 +9,24 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") | ||||
| { | { | ||||
| Tensor output = null; | |||||
| dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
| Python.with(new ops.name_scope(name, "zeros", shape), self => | |||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "zeros", shape), scope => | |||||
| { | { | ||||
| name = self as ops.name_scope; | |||||
| name = scope; | |||||
| switch (dtype) | switch (dtype) | ||||
| { | { | ||||
| case TF_DataType.TF_BOOL: | case TF_DataType.TF_BOOL: | ||||
| output = _constant_if_small(false, shape, dtype, name); | |||||
| break; | |||||
| return _constant_if_small(false, shape, dtype, name); | |||||
| case TF_DataType.TF_DOUBLE: | case TF_DataType.TF_DOUBLE: | ||||
| output = _constant_if_small(0.0D, shape, dtype, name); | |||||
| break; | |||||
| return _constant_if_small(0.0D, shape, dtype, name); | |||||
| case TF_DataType.TF_FLOAT: | case TF_DataType.TF_FLOAT: | ||||
| output = _constant_if_small(0.0F, shape, dtype, name); | |||||
| break; | |||||
| return _constant_if_small(0.0F, shape, dtype, name); | |||||
| case TF_DataType.TF_INT32: | case TF_DataType.TF_INT32: | ||||
| output = _constant_if_small(0, shape, dtype, name); | |||||
| break; | |||||
| return _constant_if_small(0, shape, dtype, name); | |||||
| default: | default: | ||||
| break; | |||||
| throw new TypeError("can't find type for zeros"); | |||||
| } | } | ||||
| }); | }); | ||||
| return output; | |||||
| } | } | ||||
| private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) | private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dtype, string name) | ||||
| @@ -31,6 +31,12 @@ TensorFlow 1.13 RC.</PackageReleaseNotes> | |||||
| <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | |||||
| <Compile Remove="runtimes\**" /> | |||||
| <EmbeddedResource Remove="runtimes\**" /> | |||||
| <None Remove="runtimes\**" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <None Remove="Protobuf\README.md" /> | <None Remove="Protobuf\README.md" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -13,28 +13,35 @@ namespace Tensorflow | |||||
| private TF_DataType _dtype; | private TF_DataType _dtype; | ||||
| public string name { get; set; } | public string name { get; set; } | ||||
| public VariableScope() | |||||
| public VariableScope(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
| { | { | ||||
| _reuse = _ReuseMode.AUTO_REUSE; | _reuse = _ReuseMode.AUTO_REUSE; | ||||
| _dtype = dtype; | |||||
| } | } | ||||
| public RefVariable get_variable(_VariableStore var_store, | public RefVariable get_variable(_VariableStore var_store, | ||||
| string name, | string name, | ||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| IInitializer initializer = null, | |||||
| bool? trainable = null, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| VariableAggregation aggregation= VariableAggregation.NONE) | VariableAggregation aggregation= VariableAggregation.NONE) | ||||
| { | { | ||||
| string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name; | ||||
| return Python.with<ops.name_scope, Tensor>(new ops.name_scope(""), scope => | |||||
| return Python.with<ops.name_scope, RefVariable>(new ops.name_scope(""), scope => | |||||
| { | { | ||||
| if (dtype == TF_DataType.DtInvalid) | if (dtype == TF_DataType.DtInvalid) | ||||
| dtype = _dtype; | dtype = _dtype; | ||||
| return var_store.get_variable(full_name); | |||||
| return var_store.get_variable(full_name, | |||||
| shape: shape, | |||||
| dtype: dtype, | |||||
| initializer: initializer, | |||||
| trainable: trainable, | |||||
| synchronization: synchronization, | |||||
| aggregation: aggregation); | |||||
| }); | }); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -24,7 +24,7 @@ namespace Tensorflow | |||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| bool trainable = false, | |||||
| bool? trainable = null, | |||||
| bool validate_shape = true, | bool validate_shape = true, | ||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| VariableAggregation aggregation = VariableAggregation.NONE) | VariableAggregation aggregation = VariableAggregation.NONE) | ||||
| @@ -44,14 +44,23 @@ namespace Tensorflow | |||||
| private RefVariable _true_getter(string name, | private RefVariable _true_getter(string name, | ||||
| TensorShape shape = null, | TensorShape shape = null, | ||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| bool trainable = false, | |||||
| bool? trainable = null, | |||||
| bool validate_shape = true, | bool validate_shape = true, | ||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| VariableAggregation aggregation = VariableAggregation.NONE) | VariableAggregation aggregation = VariableAggregation.NONE) | ||||
| { | { | ||||
| return _get_single_variable(name: name); | |||||
| bool is_scalar = shape.NDim == 0; | |||||
| return _get_single_variable(name: name, | |||||
| shape: shape, | |||||
| dtype: dtype, | |||||
| initializer: initializer, | |||||
| trainable: trainable, | |||||
| validate_shape: validate_shape, | |||||
| synchronization: synchronization, | |||||
| aggregation: aggregation); | |||||
| } | } | ||||
| private RefVariable _get_single_variable(string name, | private RefVariable _get_single_variable(string name, | ||||
| @@ -59,11 +68,14 @@ namespace Tensorflow | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| bool reuse = false, | bool reuse = false, | ||||
| bool trainable = false, | |||||
| bool? trainable = null, | |||||
| bool validate_shape = false, | bool validate_shape = false, | ||||
| bool? use_resource = null, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| VariableAggregation aggregation = VariableAggregation.NONE) | VariableAggregation aggregation = VariableAggregation.NONE) | ||||
| { | { | ||||
| bool initializing_from_value = false; | |||||
| if (_vars.ContainsKey(name)) | if (_vars.ContainsKey(name)) | ||||
| { | { | ||||
| if (!reuse) | if (!reuse) | ||||
| @@ -74,7 +86,35 @@ namespace Tensorflow | |||||
| throw new NotImplementedException("_get_single_variable"); | throw new NotImplementedException("_get_single_variable"); | ||||
| } | } | ||||
| throw new NotImplementedException("_get_single_variable"); | |||||
| Tensor init_val = null; | |||||
| ops.init_scope(); | |||||
| { | |||||
| if (initializing_from_value) | |||||
| { | |||||
| } | |||||
| else | |||||
| { | |||||
| init_val = initializer.call(shape, dtype); | |||||
| var variable_dtype = dtype.as_base_dtype(); | |||||
| } | |||||
| } | |||||
| // Create the variable. | |||||
| if (use_resource == null) | |||||
| use_resource = false; | |||||
| var v = variable_scope.default_variable_creator(init_val, | |||||
| name: name, | |||||
| trainable: trainable, | |||||
| dtype: TF_DataType.DtInvalid, | |||||
| validate_shape: validate_shape, | |||||
| synchronization: synchronization, | |||||
| aggregation: aggregation); | |||||
| _vars[name] = v; | |||||
| return v; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -12,14 +12,22 @@ namespace Tensorflow | |||||
| return variables.variables_initializer(g.ToArray()); | return variables.variables_initializer(g.ToArray()); | ||||
| } | } | ||||
| public static RefVariable get_variable(string name, | |||||
| TensorShape shape = null, | |||||
| public static RefVariable get_variable(string name, | |||||
| TensorShape shape = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| IInitializer initializer = null, | IInitializer initializer = null, | ||||
| bool? trainable = null, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
| VariableAggregation aggregation = VariableAggregation.NONE) | VariableAggregation aggregation = VariableAggregation.NONE) | ||||
| { | { | ||||
| var scope = variable_scope.get_variable_scope(); | |||||
| var store = variable_scope._get_default_variable_store(); | var store = variable_scope._get_default_variable_store(); | ||||
| return variable_scope.get_variable_scope().get_variable(store, name, shape: shape); | |||||
| return scope.get_variable(store, | |||||
| name, | |||||
| shape: shape, | |||||
| dtype: dtype, | |||||
| initializer: initializer, | |||||
| trainable: trainable); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,9 +10,16 @@ namespace Tensorflow | |||||
| public static string _VARSCOPESTORE_KEY = "__varscope"; | public static string _VARSCOPESTORE_KEY = "__varscope"; | ||||
| public static bool _DEFAULT_USE_RESOURCE = false; | public static bool _DEFAULT_USE_RESOURCE = false; | ||||
| public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO) | |||||
| public static RefVariable default_variable_creator(object initial_value, | |||||
| string name = "", | |||||
| bool? trainable = null, | |||||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||||
| bool validate_shape = false, | |||||
| bool ? use_resource = null, | |||||
| VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
| VariableAggregation aggregation = VariableAggregation.NONE) | |||||
| { | { | ||||
| var trainable = _get_trainable_value(synchronization); | |||||
| trainable = _get_trainable_value(synchronization, trainable); | |||||
| if (!use_resource.HasValue) | if (!use_resource.HasValue) | ||||
| { | { | ||||
| use_resource = get_variable_scope().use_resource; | use_resource = get_variable_scope().use_resource; | ||||
| @@ -77,18 +84,22 @@ namespace Tensorflow | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true) | |||||
| public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = true) | |||||
| { | { | ||||
| if (synchronization == VariableSynchronization.ON_READ) | if (synchronization == VariableSynchronization.ON_READ) | ||||
| { | { | ||||
| if (trainable) | |||||
| if (trainable.Value) | |||||
| throw new ValueError("Synchronization value can be set to " + | throw new ValueError("Synchronization value can be set to " + | ||||
| "VariableSynchronization.ON_READ only for non-trainable variables. " + | "VariableSynchronization.ON_READ only for non-trainable variables. " + | ||||
| "You have specified trainable=True and " + | "You have specified trainable=True and " + | ||||
| "synchronization=VariableSynchronization.ON_READ."); | "synchronization=VariableSynchronization.ON_READ."); | ||||
| } | } | ||||
| else if (!trainable.HasValue) | |||||
| { | |||||
| trainable = true; | |||||
| } | |||||
| return trainable; | |||||
| return trainable.Value; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||