| @@ -17,8 +17,8 @@ namespace Tensorflow.Graphs | |||||
| IntPtr func_handle; | IntPtr func_handle; | ||||
| public string FuncName => _graph_key; | public string FuncName => _graph_key; | ||||
| public Tensors Inputs { get; set; } | |||||
| public Tensors Outputs { get; set; } | |||||
| public Tensors Inputs { get; set; } = new Tensors(); | |||||
| public Tensors Outputs { get; set; } = new Tensors(); | |||||
| public Dictionary<string, string> Attrs { get; set; } | public Dictionary<string, string> Attrs { get; set; } | ||||
| public Dictionary<long, (Tensor, Tensor)> _captures | public Dictionary<long, (Tensor, Tensor)> _captures | ||||
| @@ -175,14 +175,7 @@ namespace Tensorflow.Graphs | |||||
| void add_capture(Tensor tensor, Tensor placeholder) | void add_capture(Tensor tensor, Tensor placeholder) | ||||
| { | { | ||||
| _captures.Add(tensor.Id, (tensor, placeholder)); | _captures.Add(tensor.Id, (tensor, placeholder)); | ||||
| if (Inputs == null) | |||||
| Inputs = new Tensors(placeholder); | |||||
| else | |||||
| { | |||||
| var inputs = Inputs.ToList(); | |||||
| inputs.Add(placeholder); | |||||
| Inputs = new Tensors(inputs.ToArray()); | |||||
| } | |||||
| Inputs.Add(placeholder); | |||||
| } | } | ||||
| Tensor _create_substitute_placeholder(Tensor value, | Tensor _create_substitute_placeholder(Tensor value, | ||||
| @@ -39,7 +39,8 @@ namespace Tensorflow | |||||
| public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | ||||
| { | { | ||||
| _graph = g ?? ops.get_default_graph(); | _graph = g ?? ops.get_default_graph(); | ||||
| _graph.as_default(); | |||||
| if (!_graph.building_function) | |||||
| _graph.as_default(); | |||||
| _target = Encoding.UTF8.GetBytes(target); | _target = Encoding.UTF8.GetBytes(target); | ||||
| using (var opts = new SessionOptions(target, config)) | using (var opts = new SessionOptions(target, config)) | ||||
| @@ -58,9 +58,6 @@ namespace Tensorflow.Keras.Layers | |||||
| args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; | args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; | ||||
| } | } | ||||
| // In graph mode, create a graph placeholder to call the layer on. | |||||
| tf.Context.graph_mode(); | |||||
| if (args.InputTensor == null) | if (args.InputTensor == null) | ||||
| { | { | ||||
| if (args.InputShape != null) | if (args.InputShape != null) | ||||
| @@ -74,6 +71,9 @@ namespace Tensorflow.Keras.Layers | |||||
| args.BatchInputShape = null; | args.BatchInputShape = null; | ||||
| } | } | ||||
| var graph = keras.backend.get_graph(); | |||||
| graph.as_default(); | |||||
| args.InputTensor = keras.backend.placeholder( | args.InputTensor = keras.backend.placeholder( | ||||
| shape: BatchInputShape, | shape: BatchInputShape, | ||||
| dtype: DType, | dtype: DType, | ||||
| @@ -81,8 +81,8 @@ namespace Tensorflow.Keras.Layers | |||||
| sparse: args.Sparse, | sparse: args.Sparse, | ||||
| ragged: args.Ragged); | ragged: args.Ragged); | ||||
| isPlaceholder = true; | isPlaceholder = true; | ||||
| tf.Context.restore_mode(); | |||||
| } | } | ||||
| // Create an input node to add to self.outbound_node | // Create an input node to add to self.outbound_node | ||||
| @@ -97,8 +97,6 @@ namespace Tensorflow.Keras.Layers | |||||
| typeSpec = new TensorSpec(args.InputTensor.TensorShape, | typeSpec = new TensorSpec(args.InputTensor.TensorShape, | ||||
| dtype: args.InputTensor.dtype, | dtype: args.InputTensor.dtype, | ||||
| name: Name); | name: Name); | ||||
| tf.Context.restore_mode(); | |||||
| } | } | ||||
| public static InputLayer from_config(LayerArgs args) | public static InputLayer from_config(LayerArgs args) | ||||
| @@ -151,23 +151,12 @@ namespace Tensorflow.Keras.Utils | |||||
| // recursively | // recursively | ||||
| CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); | ||||
| Layer op_layer = null; | |||||
| /*var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||||
| Layer op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||||
| { | { | ||||
| NodeDef = op.node_def, | NodeDef = op.node_def, | ||||
| Constants = constants, | Constants = constants, | ||||
| Name = op.name | Name = op.name | ||||
| });*/ | |||||
| op_layer = op.type switch | |||||
| { | |||||
| // "AddV2" => keras.layers.Add(), | |||||
| _ => new TensorFlowOpLayer(new TensorFlowOpLayerArgs | |||||
| { | |||||
| NodeDef = op.node_def, | |||||
| Constants = constants, | |||||
| Name = op.name | |||||
| }) | |||||
| }; | |||||
| }); | |||||
| created_layers.Add(op_layer); | created_layers.Add(op_layer); | ||||
| op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); | op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); | ||||
| processed_ops.Add(op); | processed_ops.Add(op); | ||||