| @@ -242,7 +242,7 @@ namespace Tensorflow | |||||
| throw new RuntimeError("Graph is finalized and cannot be modified."); | throw new RuntimeError("Graph is finalized and cannot be modified."); | ||||
| } | } | ||||
| public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | |||||
| public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | |||||
| TF_DataType[] input_types = null, string name = null, | TF_DataType[] input_types = null, string name = null, | ||||
| Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
| { | { | ||||
| @@ -49,6 +49,11 @@ namespace Tensorflow.Train | |||||
| ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var); | ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var); | ||||
| _averages[var] = avg; | _averages[var] = avg; | ||||
| } | } | ||||
| else | |||||
| { | |||||
| // avg = slot_creator.create_zeros_slot( | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| } | } | ||||
| return tf_with(ops.name_scope(name), scope => | return tf_with(ops.name_scope(name), scope => | ||||
| @@ -19,14 +19,14 @@ namespace Tensorflow.Train | |||||
| public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay, | public static Tensor assign_moving_average(RefVariable variable, RefVariable value, Tensor decay, | ||||
| bool zero_debias = true, string name = null) | bool zero_debias = true, string name = null) | ||||
| { | { | ||||
| tf_with(ops.name_scope(name, "", new { variable, value, decay }), scope => | |||||
| return tf_with(ops.name_scope(name, "AssignMovingAvg", new { variable, value, decay }), scope => | |||||
| { | { | ||||
| decay = ops.convert_to_tensor(1.0f - decay, name: "decay"); | decay = ops.convert_to_tensor(1.0f - decay, name: "decay"); | ||||
| if (decay.dtype != variable.dtype.as_base_dtype()) | if (decay.dtype != variable.dtype.as_base_dtype()) | ||||
| decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); | decay = math_ops.cast(decay, variable.dtype.as_base_dtype()); | ||||
| }); | |||||
| throw new NotImplementedException("assign_moving_average"); | |||||
| return decay; | |||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -273,10 +273,14 @@ namespace Tensorflow | |||||
| new_op_type = "Switch"; | new_op_type = "Switch"; | ||||
| var new_op_name = op.node_def.Name + "_" + name; | var new_op_name = op.node_def.Name + "_" + name; | ||||
| new_op_name = new_op_name.Replace(":", "_"); | new_op_name = new_op_name.Replace(":", "_"); | ||||
| var attrs = new Dictionary<string, AttrValue>(); | |||||
| attrs[op.node_def.Name] = op.node_def.Attr.ElementAt(0).Value; | |||||
| /*return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types, | |||||
| name: new_op_name, attrs: attrs);*/ | |||||
| // Convert attr values to AttrValue protos. | |||||
| var attr_protos = new Dictionary<string, AttrValue>(); | |||||
| foreach (var attr_def in op.node_def.Attr) | |||||
| attr_protos[attr_def.Key] = attr_def.Value; | |||||
| return op.graph.create_op(new_op_type, new_op_inputs.ToArray(), op._output_types, | |||||
| name: new_op_name, attrs: attr_protos); | |||||
| } | } | ||||
| return op; | return op; | ||||
| } | } | ||||