You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

ops.py.cs 13 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
6 years ago
7 years ago
6 years ago
7 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Runtime.InteropServices;
  4. using System.Text;
  5. using System.Threading;
  6. using Tensorflow;
  7. using node_def_pb2 = Tensorflow;
  8. using Google.Protobuf;
  9. using System.Linq;
  10. using NumSharp.Core;
  11. using System.ComponentModel;
  namespace Tensorflow
  {
      public partial class ops
      {
          /// <summary>
          /// Adds <paramref name="value"/> to the collection named <paramref name="name"/>
          /// in the graph returned by <c>tf.get_default_graph()</c>.
          /// </summary>
          /// <typeparam name="T">Type of the value being stored.</typeparam>
          /// <param name="name">The key of the collection.</param>
          /// <param name="value">The value to add.</param>
          public static void add_to_collection<T>(string name, T value)
          {
              var graph = tf.get_default_graph();
              graph.add_to_collection(name, value);
          }

          /// <summary>
          /// Adds <paramref name="value"/> to every collection listed in <paramref name="names"/>
          /// in the graph returned by <c>tf.get_default_graph()</c>.
          /// </summary>
          /// <typeparam name="T">Type of the value being stored.</typeparam>
          /// <param name="names">The keys of the collections.</param>
          /// <param name="value">The value to add.</param>
          public static void add_to_collections<T>(List<string> names, T value)
          {
              var graph = tf.get_default_graph();
              graph.add_to_collections(names, value);
          }

          /// <summary>
          /// Wrapper for `Graph.get_collection()` using the default graph.
          /// </summary>
          /// <param name="key">
          /// The key for the collection. For example, the `GraphKeys` class
          /// contains many standard names for collections.
          /// </param>
          /// <param name="scope">Scope argument forwarded unchanged to <c>Graph.get_collection</c>.</param>
          /// <returns>
          /// The list of values in the collection with the given `name`, or
          /// an empty list if no value has been added to that collection. The
          /// list contains the values in the order under which they were
          /// collected.
          /// </returns>
          public static object get_collection(string key, string scope = "")
          {
              return get_default_graph().get_collection(key, scope);
          }

          /// <summary>
          /// Returns the graph produced by <c>tf.Graph()</c>.
          /// NOTE(review): confirm <c>tf.Graph()</c> returns a shared default graph
          /// rather than constructing a new one on every call.
          /// </summary>
          public static Graph get_default_graph()
          {
              return tf.Graph();
          }

          /// <summary>
          /// Returns the graph that ops built from <paramref name="op_input_list"/> should live in.
          /// </summary>
          /// <remarks>
          /// TODO: neither the inputs nor <paramref name="graph"/> are validated yet;
          /// the default graph is always returned.
          /// </remarks>
          /// <param name="op_input_list">Inputs of the op being created (currently unused; may be null).</param>
          /// <param name="graph">Explicitly requested graph (currently unused).</param>
          /// <returns>The default graph.</returns>
          public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
          {
              return get_default_graph();
          }

          /// <summary>
          /// Converts the given `value` to a `Tensor`.
          /// </summary>
          /// <param name="value">A <see cref="Tensor"/> (returned as-is) or any value accepted by
          /// <c>tensor_util.convert_to_numpy_ndarray</c>.</param>
          /// <param name="dtype">Requested element type. NOTE(review): currently not applied.</param>
          /// <param name="name">Name for the constant op created for non-tensor values.</param>
          /// <returns><paramref name="value"/> itself if it is a tensor, otherwise a new constant tensor.</returns>
          public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
          {
              switch (value)
              {
                  case Tensor tensor:
                      return tensor;
                  default:
                      var nd = tensor_util.convert_to_numpy_ndarray(value);
                      return constant_op.constant(nd, name);
              }
          }

          /// <summary>
          /// Wrapper for `Graph.control_dependencies()` using the default graph.
          /// </summary>
          /// <param name="control_inputs">Operations that must run before ops created inside the returned scope.</param>
          /// <returns>The controller returned by the default graph.</returns>
          public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
          {
              return get_default_graph().control_dependencies(control_inputs);
          }

          /// <summary>
          /// Creates a TF_Operation.
          /// </summary>
          /// <param name="graph">a `Graph`.</param>
          /// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param>
          /// <param name="inputs">
          /// Tensors to add as inputs, in order; may be null. Each tensor is added as a
          /// single input. TODO: sequence inputs ("int64 * N", "list(int64)") are not supported yet.
          /// </param>
          /// <param name="control_inputs">A list of `Operation`s to set as control dependencies; may be null.</param>
          /// <returns>A wrapped TF_Operation*.</returns>
          public static IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs, Operation[] control_inputs)
          {
              var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

              // Add inputs.
              if (inputs != null)
              {
                  foreach (var op_input in inputs)
                  {
                      c_api.TF_AddInput(op_desc, op_input._as_tf_output());
                  }
              }

              var status = new Status();

              // Add control inputs.
              if (control_inputs != null)
              {
                  foreach (var control_input in control_inputs)
                  {
                      c_api.TF_AddControlInput(op_desc, control_input);
                  }
              }

              // Add attrs: each AttrValue is serialized and handed to the C API
              // through a temporary unmanaged buffer.
              foreach (var attr in node_def.Attr)
              {
                  var bytes = attr.Value.ToByteArray();
                  var proto = Marshal.AllocHGlobal(bytes.Length);
                  try
                  {
                      Marshal.Copy(bytes, 0, proto, bytes.Length);
                      c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (uint)bytes.Length, status: status);
                  }
                  finally
                  {
                      // TF_SetAttrValueProto parses the serialized proto during the call,
                      // so the buffer can be released immediately; previously it leaked.
                      Marshal.FreeHGlobal(proto);
                  }
                  status.Check(true);
              }

              var c_op = c_api.TF_FinishOperation(op_desc, status);
              status.Check(true);
              return c_op;
          }

          /// <summary>
          /// Looks up the OpDef registered for <paramref name="type"/> via <c>graph.GetOpDef</c>.
          /// </summary>
          public static OpDef _get_op_def(Graph graph, string type)
          {
              return graph.GetOpDef(type);
          }

          /// <summary>
          /// Builds a <see cref="NodeDef"/> with the given op type, name and attributes.
          /// </summary>
          /// <param name="op_type">Value for <c>NodeDef.Op</c>.</param>
          /// <param name="name">Value for <c>NodeDef.Name</c>.</param>
          /// <param name="device">NOTE(review): currently not written to the NodeDef.</param>
          /// <param name="attrs">Attributes to copy into <c>NodeDef.Attr</c>; null means none.</param>
          /// <returns>The new NodeDef.</returns>
          public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
          {
              var node_def = new node_def_pb2.NodeDef();
              node_def.Op = op_type;
              node_def.Name = name;
              // attrs defaults to null, which previously caused a NullReferenceException here.
              if (attrs != null)
              {
                  foreach (var attr in attrs)
                  {
                      node_def.Attr.Add(attr.Key, attr.Value);
                  }
              }
              return node_def;
          }

          /// <summary>
          /// Strips a single trailing "/" from a name scope to obtain a plain name.
          /// </summary>
          /// <param name="name">Scope name; null or empty values are returned unchanged.</param>
          /// <returns><paramref name="name"/> without its trailing slash, if it had one.</returns>
          public static string _name_from_scope_name(string name)
          {
              if (string.IsNullOrEmpty(name))
              {
                  return name;
              }

              return name.EndsWith("/", StringComparison.Ordinal)
                  ? name.Substring(0, name.Length - 1)
                  : name;
          }

          /// <summary>
          /// Partial port of the context manager that lifts ops out of control-flow scopes
          /// and function-building graphs.
          /// </summary>
          /// <remarks>
          /// TODO: currently a stub. It computes the current name scope (not yet used) and
          /// enters <c>control_dependencies(null)</c> around an empty body.
          /// </remarks>
          public static void init_scope()
          {
              // Retrieve the active name scope: entering an `init_scope` preserves
              // the name scope of the current context.
              var default_graph = get_default_graph();
              var scope = default_graph.get_name_scope();
              if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
              {
                  // Names that end with trailing slashes are treated by `name_scope` as
                  // absolute.
                  scope += "/";
              }
              // inner_device_stack = default_graph._device_function_stack
              // var outer_context = default_graph.as_default;
              Python.with(ops.control_dependencies(null), delegate
              {
                  var outer_graph = get_default_graph();
                  // outer_device_stack = None
              });
          }

          private static int uid_number = 0;

          /// <summary>
          /// A unique (within this program execution) integer. The sequence starts at 0.
          /// Safe to call concurrently.
          /// </summary>
          /// <returns>The next unique id.</returns>
          public static int uid()
          {
              // Interlocked.Increment returns the incremented value; subtracting 1 keeps the
              // same 0-based sequence as the former (non-atomic) post-increment.
              return Interlocked.Increment(ref uid_number) - 1;
          }

          /// <summary>
          /// Requests colocation with <paramref name="op"/> on the default graph.
          /// </summary>
          public static void colocate_with(Operation op, bool ignore_existing = false)
          {
              _colocate_with_for_gradient(op, null, ignore_existing);
          }

          /// <summary>
          /// Requests colocation with the op that produces <paramref name="tensor"/> on the default graph.
          /// </summary>
          public static void colocate_with(Tensor tensor, bool ignore_existing = false)
          {
              _colocate_with_for_gradient(tensor.op, null, ignore_existing);
          }

          /// <summary>
          /// Forwards a colocation request to <c>Graph._colocate_with_for_gradient</c> on the default graph.
          /// </summary>
          private static void _colocate_with_for_gradient(Operation op, int? gradient_uid, bool ignore_existing = false)
          {
              var default_graph = get_default_graph();
              default_graph._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
          }

          /// <summary>
          /// Uses the default session (or <paramref name="session"/>) to evaluate a tensor.
          /// </summary>
          /// <param name="tensor">The tensor to evaluate.</param>
          /// <param name="feed_dict">Feed items passed through to <c>Session.run</c>.</param>
          /// <param name="graph">The graph in which the tensor is defined.</param>
          /// <param name="session">A different session to use to evaluate the tensor.</param>
          /// <returns>The value produced by <c>session.run(tensor, feed_dict)</c>.</returns>
          /// <exception cref="ValueError">
          /// No session is available, or the session's graph differs from <paramref name="graph"/>.
          /// </exception>
          public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed_dict, Graph graph, Session session = null)
          {
              if (session == null)
              {
                  session = get_default_session();
                  if (session == null)
                  {
                      throw new ValueError("Cannot evaluate tensor using `eval()`: No default " +
                          "session is registered. Use `with " +
                          "sess.as_default()` or pass an explicit session to " +
                          "`eval(session=sess)`");
                  }
                  if (session.graph != graph)
                  {
                      throw new ValueError("Cannot use the default session to evaluate tensor: " +
                          "the tensor's graph is different from the session's " +
                          "graph. Pass an explicit session to " +
                          "`eval(session=sess)`.");
                  }
              }
              else if (session.graph != graph)
              {
                  // An explicit session was supplied, so the "default session" wording
                  // (and the advice to pass an explicit session) does not apply here.
                  throw new ValueError("Cannot use the given session to evaluate tensor: " +
                      "the tensor's graph is different from the session's " +
                      "graph.");
              }

              return session.run(tensor, feed_dict);
          }

          /// <summary>
          /// Returns the default session for the current thread.
          /// NOTE(review): delegates to <c>tf.Session()</c>; confirm it returns a shared
          /// default session rather than constructing a new one.
          /// </summary>
          /// <returns>The session returned by <c>tf.Session()</c>.</returns>
          public static Session get_default_session()
          {
              return tf.Session();
          }

          /// <summary>
          /// Returns a gradient function that dispatches on the operation type.
          /// </summary>
          /// <param name="op">Operation whose inputs are checked; the returned delegate receives its own operation argument.</param>
          /// <returns>
          /// null when <c>op.inputs</c> is null; otherwise a delegate supporting "Add", "Sum"
          /// and "RealDiv", throwing <see cref="NotImplementedException"/> for other types.
          /// </returns>
          public static Func<Operation, Tensor, (Tensor, Tensor)> get_gradient_function(Operation op)
          {
              if (op.inputs == null)
              {
                  return null;
              }

              return (oper, out_grads) =>
              {
                  switch (oper.type)
                  {
                      case "Add":
                          return math_grad._AddGrad(oper, out_grads);
                      case "Sum":
                          return math_grad._SumGrad(oper, out_grads);
                      case "RealDiv":
                          return math_grad._RealDivGrad(oper, out_grads);
                      default:
                          throw new NotImplementedException($"get_gradient_function {oper.type}");
                  }
              };
          }

          /// <summary>
          /// Converts each element of <paramref name="values"/> with <see cref="internal_convert_to_tensor{T}"/>.
          /// </summary>
          /// <param name="name">Base name; when non-empty, element i is named "{name}_{i}".</param>
          /// <returns>The converted tensors, in input order.</returns>
          public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
              string name = "", DataType preferred_dtype = DataType.DtInvalid,
              bool as_ref = false)
          {
              var ret = new List<Tensor>();
              foreach ((int i, T value) in Python.enumerate(values))
              {
                  string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
                  ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype));
              }
              return ret.ToArray();
          }

          /// <summary>
          /// Returns <paramref name="value"/> if it is a <see cref="Tensor"/>; otherwise wraps it in a constant op.
          /// </summary>
          /// <param name="name">Name for the constant op created for non-tensor values.</param>
          /// <remarks>NOTE(review): <paramref name="dtype"/>, <paramref name="preferred_dtype"/> and
          /// <paramref name="as_ref"/> are currently not applied.</remarks>
          public static Tensor internal_convert_to_tensor<T>(T value, DataType dtype = DataType.DtInvalid,
              string name = "", DataType preferred_dtype = DataType.DtInvalid,
              bool as_ref = false)
          {
              // Check the real type (not just a type *named* "Tensor") and also the runtime
              // value, so a Tensor passed through a wider T such as object is returned as-is
              // instead of being re-wrapped through np.array.
              if (typeof(T) == typeof(Tensor) || value is Tensor)
              {
                  return value as Tensor;
              }

              return constant_op.constant(np.array(value), name);
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。