Browse Source

_safe_initial_value_from_op #360

tags/v0.12
Oceania2018 6 years ago
parent
commit
23edfd09f8
4 changed files with 17 additions and 8 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Train/moving_averages.cs
  4. +8
    -4
      src/TensorFlowNET.Core/Variables/RefVariable.cs

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -242,7 +242,7 @@ namespace Tensorflow
throw new RuntimeError("Graph is finalized and cannot be modified."); throw new RuntimeError("Graph is finalized and cannot be modified.");
} }


public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = null, TF_DataType[] input_types = null, string name = null,
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
{ {


+ 5
- 0
src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs View File

@@ -49,6 +49,11 @@ namespace Tensorflow.Train
ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var); ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var);
_averages[var] = avg; _averages[var] = avg;
} }
else
{
// avg = slot_creator.create_zeros_slot(
throw new NotImplementedException("");
}
} }


return tf_with(ops.name_scope(name), scope => return tf_with(ops.name_scope(name), scope =>


+ 3
- 3
src/TensorFlowNET.Core/Train/moving_averages.cs View File

@@ -19,14 +19,14 @@ namespace Tensorflow.Train
public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay, public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay,
bool zero_debias = true, string name = null) bool zero_debias = true, string name = null)
{ {
tf_with(ops.name_scope(name, "", new { variable, value, decay }), scope =>
return tf_with(ops.name_scope(name, "AssignMovingAvg", new { variable, value, decay }), scope =>
{ {
decay = ops.convert_to_tensor(1.0f - decay, name: "decay"); decay = ops.convert_to_tensor(1.0f - decay, name: "decay");
if (decay.dtype != variable.dtype.as_base_dtype()) if (decay.dtype != variable.dtype.as_base_dtype())
decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); decay = math_ops.cast(decay, variable.dtype.as_base_dtype());
});


throw new NotImplementedException("assign_moving_average");
return decay;
});
} }
} }
} }

+ 8
- 4
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -273,10 +273,14 @@ namespace Tensorflow
new_op_type = "Switch"; new_op_type = "Switch";
var new_op_name = op.node_def.Name + "_" + name; var new_op_name = op.node_def.Name + "_" + name;
new_op_name = new_op_name.Replace(":", "_"); new_op_name = new_op_name.Replace(":", "_");
var attrs = new Dictionary<string, AttrValue>();
attrs[op.node_def.Name] = op.node_def.Attr.ElementAt(0).Value;
/*return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types,
name: new_op_name, attrs: attrs);*/

// Convert attr values to AttrValue protos.
var attr_protos = new Dictionary<string, AttrValue>();
foreach (var attr_def in op.node_def.Attr)
attr_protos[attr_def.Key] = attr_def.Value;

return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types,
name: new_op_name, attrs: attr_protos);
} }
return op; return op;
} }


Loading…
Cancel
Save