| @@ -9,5 +9,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| internal GradientTape _tapeSet; | |||||
| GradientTape _tapeSet; | |||||
| /// <summary> | /// <summary> | ||||
| /// Record operations for automatic differentiation. | /// Record operations for automatic differentiation. | ||||
| @@ -79,42 +79,6 @@ namespace Tensorflow | |||||
| return _GatherReturnElements(return_elements, graph, results); | return _GatherReturnElements(return_elements, graph, results); | ||||
| } | } | ||||
| //private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary<string, Tensor> input_map = null, string[] return_elements = null, | |||||
| // bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null) | |||||
| //{ | |||||
| // graph_def = _ProcessGraphDefParam(graph_def); | |||||
| // input_map = _ProcessInputMapParam(input_map); | |||||
| // return_elements = _ProcessReturnElementsParam(return_elements); | |||||
| // if(producer_op_list is not null) | |||||
| // { | |||||
| // _RemoveDefaultAttrs(producer_op_list, graph_def); | |||||
| // } | |||||
| // var graph = ops.get_default_graph(); | |||||
| // string prefix = null; | |||||
| // tf_with(ops.name_scope(name, "import", input_map.Values), scope => | |||||
| // { | |||||
| // if (scope is not null) | |||||
| // { | |||||
| // Debug.Assert(scope.scope_name.EndsWith("/")); | |||||
| // prefix = scope.scope_name[scope.scope_name.Length - 1].ToString(); | |||||
| // } | |||||
| // else | |||||
| // { | |||||
| // prefix = ""; | |||||
| // } | |||||
| // input_map = _ConvertInputMapValues(name, input_map); | |||||
| // }); | |||||
| // var scope_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||||
| // var options = scope_options.Options; | |||||
| // _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints); | |||||
| //} | |||||
| private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | ||||
| Graph graph, | Graph graph, | ||||
| TF_ImportGraphDefResults results) | TF_ImportGraphDefResults results) | ||||
| @@ -305,7 +305,7 @@ namespace Tensorflow.Functions | |||||
| private Tensors _build_call_outputs(Tensors result) | private Tensors _build_call_outputs(Tensors result) | ||||
| { | { | ||||
| // TODO(Rinne): dwal with `func_graph.structured_outputs` | |||||
| // TODO(Rinne): deal with `func_graph.structured_outputs` | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -4,7 +4,7 @@ using Tensorflow.Train; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class Function: Trackable | |||||
| public class Function: Trackable, IGenericFunction | |||||
| { | { | ||||
| #pragma warning disable CS0169 // The field 'Function._handle' is never used | #pragma warning disable CS0169 // The field 'Function._handle' is never used | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| @@ -34,6 +34,11 @@ namespace Tensorflow | |||||
| return result; | return result; | ||||
| } | } | ||||
| public ConcreteFunction get_concrete_function(params Tensor[] args) | |||||
| { | |||||
| return _get_concrete_function_garbage_collected(args); | |||||
| } | |||||
| protected virtual Tensors _call(Tensors inputs) | protected virtual Tensors _call(Tensors inputs) | ||||
| { | { | ||||
| if(_variable_creation_fn is not null) | if(_variable_creation_fn is not null) | ||||
| @@ -57,6 +62,18 @@ namespace Tensorflow | |||||
| return false; | return false; | ||||
| } | } | ||||
| protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args) | |||||
| { | |||||
| if(_variable_creation_fn is null) | |||||
| { | |||||
| _initialize(args); | |||||
| // TODO(Rinne): _initialize_uninitialized_variables | |||||
| } | |||||
| var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||||
| return concrete; | |||||
| } | |||||
| private void _initialize(Tensor[] args) | private void _initialize(Tensor[] args) | ||||
| { | { | ||||
| _variable_creation_fn = _compiler(_csharp_function); | _variable_creation_fn = _compiler(_csharp_function); | ||||
| @@ -6,7 +6,7 @@ namespace Tensorflow.Functions | |||||
| { | { | ||||
| public interface IGenericFunction | public interface IGenericFunction | ||||
| { | { | ||||
| object[] Apply(params object[] args); | |||||
| ConcreteFunction get_concrete_function(params object[] args); | |||||
| Tensors Apply(Tensors args); | |||||
| ConcreteFunction get_concrete_function(params Tensor[] args); | |||||
| } | } | ||||
| } | } | ||||
| @@ -49,7 +49,7 @@ namespace Tensorflow.Functions | |||||
| private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) | private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) | ||||
| { | { | ||||
| var lookup_func_key = male_cache_key(args); | |||||
| var lookup_func_key = make_cache_key(args); | |||||
| if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) | if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) | ||||
| { | { | ||||
| return (concrete_function, args); | return (concrete_function, args); | ||||
| @@ -71,7 +71,7 @@ namespace Tensorflow.Functions | |||||
| return concrete_function; | return concrete_function; | ||||
| } | } | ||||
| private static string male_cache_key(Tensor[] inputs) | |||||
| private static string make_cache_key(Tensor[] inputs) | |||||
| { | { | ||||
| //string res = ""; | //string res = ""; | ||||
| //foreach (var input in inputs) | //foreach (var input in inputs) | ||||
| @@ -727,7 +727,7 @@ namespace Tensorflow | |||||
| private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func<Operation, Tensor[], Tensor[]> grad_fn) | ||||
| { | { | ||||
| //scope = scope.TrimEnd('/').Replace('/', '_'); | |||||
| // scope = scope.TrimEnd('/').Replace('/', '_'); | |||||
| return grad_fn(op, out_grads); | return grad_fn(op, out_grads); | ||||
| } | } | ||||
| @@ -38,21 +38,6 @@ namespace Tensorflow.Graphs | |||||
| // make function as an Operation by autograph | // make function as an Operation by autograph | ||||
| // need to restore mode when exits | // need to restore mode when exits | ||||
| //var func_graph = new FuncGraph(func_name); | |||||
| //func_graph.as_default(); | |||||
| //var input_placeholders = args.Arguments.Select(x => tf.placeholder(((Tensor)x).dtype)).ToArray(); | |||||
| //// stop the function from recursive call. | |||||
| //already_in_boundary = true; | |||||
| //var outputs = args.Method.Invoke(args.Instance, input_placeholders) as Tensors; | |||||
| //already_in_boundary = false; | |||||
| //var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
| //func_graph.ToGraph(opers, | |||||
| // input_placeholders, | |||||
| // outputs, | |||||
| // null); | |||||
| //func_graph.Exit(); | |||||
| function = new ConcreteFunction(func_name); | function = new ConcreteFunction(func_name); | ||||
| function.Enter(); | function.Enter(); | ||||
| @@ -208,9 +208,5 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||||
| } | } | ||||
| } | } | ||||
| @@ -10060,7 +10060,7 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| catch (Exception) | catch (Exception) | ||||
| { | { | ||||
| Console.WriteLine(); | |||||
| } | } | ||||
| try | try | ||||
| { | { | ||||
| @@ -10068,7 +10068,7 @@ namespace Tensorflow.Operations | |||||
| } | } | ||||
| catch (Exception) | catch (Exception) | ||||
| { | { | ||||
| Console.WriteLine(); | |||||
| } | } | ||||
| } | } | ||||