Browse Source

_safe_initial_value_from_op #360

tags/v0.12
Oceania2018 6 years ago
parent
commit
23edfd09f8
4 changed files with 17 additions and 8 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Train/moving_averages.cs
  4. +8
    -4
      src/TensorFlowNET.Core/Variables/RefVariable.cs

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -242,7 +242,7 @@ namespace Tensorflow
throw new RuntimeError("Graph is finalized and cannot be modified.");
}

public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = null,
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
{


+ 5
- 0
src/TensorFlowNET.Core/Train/ExponentialMovingAverage.cs View File

@@ -49,6 +49,11 @@ namespace Tensorflow.Train
ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var);
_averages[var] = avg;
}
else
{
// avg = slot_creator.create_zeros_slot(
throw new NotImplementedException("");
}
}

return tf_with(ops.name_scope(name), scope =>


+ 3
- 3
src/TensorFlowNET.Core/Train/moving_averages.cs View File

@@ -19,14 +19,14 @@ namespace Tensorflow.Train
public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay,
bool zero_debias = true, string name = null)
{
tf_with(ops.name_scope(name, "", new { variable, value, decay }), scope =>
return tf_with(ops.name_scope(name, "AssignMovingAvg", new { variable, value, decay }), scope =>
{
decay = ops.convert_to_tensor(1.0f - decay, name: "decay");
if (decay.dtype != variable.dtype.as_base_dtype())
decay = math_ops.cast(decay, variable.dtype.as_base_dtype());
});

throw new NotImplementedException("assign_moving_average");
return decay;
});
}
}
}

+ 8
- 4
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -273,10 +273,14 @@ namespace Tensorflow
new_op_type = "Switch";
var new_op_name = op.node_def.Name + "_" + name;
new_op_name = new_op_name.Replace(":", "_");
var attrs = new Dictionary<string, AttrValue>();
attrs[op.node_def.Name] = op.node_def.Attr.ElementAt(0).Value;
/*return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types,
name: new_op_name, attrs: attrs);*/

// Convert attr values to AttrValue protos.
var attr_protos = new Dictionary<string, AttrValue>();
foreach (var attr_def in op.node_def.Attr)
attr_protos[attr_def.Key] = attr_def.Value;

return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types,
name: new_op_name, attrs: attr_protos);
}
return op;
}


Loading…
Cancel
Save