| @@ -24,7 +24,7 @@ namespace Tensorflow | |||||
| _ops[op_def.Name] = op_def; | _ops[op_def.Name] = op_def; | ||||
| } | } | ||||
| public unsafe Operation _apply_op_helper(string op_type_name, string name = "", DataType? dtype = null, TensorShape shape = null) | |||||
| public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||||
| { | { | ||||
| var op_def = _ops[op_type_name]; | var op_def = _ops[op_type_name]; | ||||
| @@ -46,9 +46,30 @@ namespace Tensorflow | |||||
| var key = attr_def.Name; | var key = attr_def.Name; | ||||
| } | } | ||||
| foreach(var input_arg in op_def.InputArg) | |||||
| var attrs = new Dictionary<string, object>(); | |||||
| var inputs = new List<Tensor>(); | |||||
| var input_types = new List<DataType>(); | |||||
| foreach (var attr in op_def.Attr) | |||||
| { | |||||
| if (keywords.ContainsKey(attr.Name)) | |||||
| { | |||||
| attrs[attr.Name] = keywords[attr.Name]; | |||||
| } | |||||
| } | |||||
| foreach (var input_arg in op_def.InputArg) | |||||
| { | { | ||||
| var input_name = input_arg.Name; | |||||
| if (keywords.ContainsKey(input_name)) | |||||
| { | |||||
| inputs.Add(keywords[input_name] as Tensor); | |||||
| } | |||||
| if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||||
| { | |||||
| attrs[input_arg.TypeAttr] = DataType.DtFloat; | |||||
| } | |||||
| } | } | ||||
| var attr_protos = new Dictionary<string, AttrValue>(); | var attr_protos = new Dictionary<string, AttrValue>(); | ||||
| @@ -60,7 +81,7 @@ namespace Tensorflow | |||||
| switch (attr_def.Type) | switch (attr_def.Type) | ||||
| { | { | ||||
| case "type": | case "type": | ||||
| attr_value.Type = dtype.Value; | |||||
| attr_value.Type = (DataType)keywords["dtype"]; | |||||
| break; | break; | ||||
| case "shape": | case "shape": | ||||
| attr_value.Shape = new TensorShapeProto(); | attr_value.Shape = new TensorShapeProto(); | ||||
| @@ -84,9 +105,9 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| var op = g.create_op(op_type_name, null, output_types.ToArray(), | |||||
| var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||||
| name: scope, | name: scope, | ||||
| input_types: new DataType[] { }, | |||||
| input_types: input_types.ToArray(), | |||||
| attrs: attr_protos, | attrs: attr_protos, | ||||
| op_def: op_def); | op_def: op_def); | ||||
| @@ -27,7 +27,10 @@ | |||||
| <None Update="tensorflow.dll"> | <None Update="tensorflow.dll"> | ||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
| </None> | </None> | ||||
| <None Update="Tensorflow\op_list_proto_bytes.bin"> | |||||
| <None Update="Tensorflow\op_list_proto_array.bin"> | |||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
| </None> | |||||
| <None Update="Tensorflow\op_list_proto_math.bin"> | |||||
| <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
| </None> | </None> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -8,11 +8,22 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static class gen_array_ops | public static class gen_array_ops | ||||
| { | { | ||||
| public static OpDefLibrary _op_def_lib => _InitOpDefLibrary(); | |||||
| public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); | |||||
| public static Tensor placeholder(DataType dtype, TensorShape shape = null) | public static Tensor placeholder(DataType dtype, TensorShape shape = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("Placeholder", dtype: dtype, shape: shape); | |||||
| /*var g = ops.get_default_graph(); | |||||
| var op = new Operation(g, "Placeholder", "feed"); | |||||
| var tensor = new Tensor(op, 0, dtype); | |||||
| return tensor;*/ | |||||
| var keywords = new Dictionary<string, object>(); | |||||
| keywords.Add("dtype", dtype); | |||||
| keywords.Add("shape", shape); | |||||
| var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); | |||||
| var _result = _op.outputs; | var _result = _op.outputs; | ||||
| var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
| var _attrs = new Dictionary<string, object>(); | var _attrs = new Dictionary<string, object>(); | ||||
| @@ -27,7 +38,7 @@ namespace Tensorflow | |||||
| private static OpDefLibrary _InitOpDefLibrary() | private static OpDefLibrary _InitOpDefLibrary() | ||||
| { | { | ||||
| // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | ||||
| var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_bytes.bin"); | |||||
| var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin"); | |||||
| var op_list = OpList.Parser.ParseFrom(bytes); | var op_list = OpList.Parser.ParseFrom(bytes); | ||||
| var op_def_lib = new OpDefLibrary(); | var op_def_lib = new OpDefLibrary(); | ||||
| op_def_lib.add_op_list(op_list); | op_def_lib.add_op_list(op_list); | ||||
| @@ -1,15 +1,34 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.IO; | |||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public static class gen_math_ops | public static class gen_math_ops | ||||
| { | { | ||||
| public static Tensor add(Tensor a, Tensor b, string name = "") | |||||
| public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); | |||||
| public static Tensor add(Tensor a, Tensor b) | |||||
| { | { | ||||
| var keywords = new Dictionary<string, object>(); | |||||
| keywords.Add("x", a); | |||||
| keywords.Add("y", b); | |||||
| var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | |||||
| return null; | return null; | ||||
| } | } | ||||
| private static OpDefLibrary _InitOpDefLibrary() | |||||
| { | |||||
| // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | |||||
| var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_math.bin"); | |||||
| var op_list = OpList.Parser.ParseFrom(bytes); | |||||
| var op_def_lib = new OpDefLibrary(); | |||||
| op_def_lib.add_op_list(op_list); | |||||
| return op_def_lib; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -20,18 +20,11 @@ namespace Tensorflow | |||||
| public static unsafe Tensor add(Tensor a, Tensor b) | public static unsafe Tensor add(Tensor a, Tensor b) | ||||
| { | { | ||||
| return null; | |||||
| return gen_math_ops.add(a, b); | |||||
| } | } | ||||
| public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) | public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) | ||||
| { | { | ||||
| /*var g = ops.get_default_graph(); | |||||
| var op = new Operation(g, "Placeholder", "feed"); | |||||
| var tensor = new Tensor(op, 0, dtype); | |||||
| return tensor;*/ | |||||
| return gen_array_ops.placeholder(dtype, shape); | return gen_array_ops.placeholder(dtype, shape); | ||||
| } | } | ||||