| @@ -129,6 +129,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| [DebuggerStepThrough] | |||||
| [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | ||||
| public static TOut tf_with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IObjectLife | public static TOut tf_with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IObjectLife | ||||
| { | { | ||||
| @@ -56,9 +56,9 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| // Instantiate an input layer. | // Instantiate an input layer. | ||||
| var x = keras.layers.Input( | var x = keras.layers.Input( | ||||
| batch_shape: batch_shape, | |||||
| dtype: dtype, | |||||
| name: layer.name + "_input"); | |||||
| batch_shape: batch_shape, | |||||
| dtype: dtype, | |||||
| name: layer.name + "_input"); | |||||
| // This will build the current layer | // This will build the current layer | ||||
| // and create the node connecting the current layer | // and create the node connecting the current layer | ||||
| @@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Engine | |||||
| if (set_inputs) | if (set_inputs) | ||||
| { | { | ||||
| // If an input layer (placeholder) is available. | // If an input layer (placeholder) is available. | ||||
| // outputs = layer._inbound_nodes; | |||||
| // outputs = layer.inbound_nodes; | |||||
| } | } | ||||
| } | } | ||||
| @@ -106,6 +106,7 @@ namespace Tensorflow.Keras.Layers | |||||
| VariableScope scope = null) | VariableScope scope = null) | ||||
| { | { | ||||
| var input_list = inputs; | var input_list = inputs; | ||||
| var input = inputs[0]; | |||||
| Tensor outputs = null; | Tensor outputs = null; | ||||
| // We will attempt to build a TF graph if & only if all inputs are symbolic. | // We will attempt to build a TF graph if & only if all inputs are symbolic. | ||||
| @@ -139,6 +140,7 @@ namespace Tensorflow.Keras.Layers | |||||
| _maybe_build(inputs[0]); | _maybe_build(inputs[0]); | ||||
| outputs = call(inputs[0], training: training); | outputs = call(inputs[0], training: training); | ||||
| (input, outputs) = _set_connectivity_metadata_(input, outputs); | |||||
| _handle_activity_regularization(inputs[0], outputs); | _handle_activity_regularization(inputs[0], outputs); | ||||
| _set_mask_metadata(inputs[0], outputs, null); | _set_mask_metadata(inputs[0], outputs, null); | ||||
| }); | }); | ||||
| @@ -147,6 +149,12 @@ namespace Tensorflow.Keras.Layers | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | |||||
| { | |||||
| //_add_inbound_node(input_tensors: inputs, output_tensors: outputs); | |||||
| return (inputs, outputs); | |||||
| } | |||||
| private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | ||||
| { | { | ||||
| //if(_activity_regularizer != null) | //if(_activity_regularizer != null) | ||||
| @@ -605,8 +605,9 @@ namespace Tensorflow | |||||
| if (axis != 0) | if (axis != 0) | ||||
| return gen_array_ops.gather_v2(@params, indices, axis, name: name); | return gen_array_ops.gather_v2(@params, indices, axis, name: name); | ||||
| if (@params is ResourceVariable variable) | |||||
| return variable.sparse_read(); | |||||
| if (@params is ResourceVariable variable && | |||||
| indices is Tensor indices_tensor) | |||||
| return variable.sparse_read(indices_tensor, name); | |||||
| return gen_array_ops.gather_v2(@params, indices, axis, name: name); | return gen_array_ops.gather_v2(@params, indices, axis, name: name); | ||||
| } | } | ||||
| @@ -73,5 +73,20 @@ namespace Tensorflow | |||||
| return _op.output; | return _op.output; | ||||
| } | } | ||||
| public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype, | |||||
| int batch_dims = 0, bool validate_indices = true, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("ResourceGather", name, new | |||||
| { | |||||
| resource, | |||||
| indices, | |||||
| dtype, | |||||
| batch_dims, | |||||
| validate_indices | |||||
| }); | |||||
| return _op.output; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -216,6 +216,18 @@ namespace Tensorflow | |||||
| _dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); | _dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); | ||||
| } | } | ||||
| public Tensor sparse_read(Tensor indices, string name = "Gather") | |||||
| { | |||||
| return tf_with(ops.name_scope(name), scope => | |||||
| { | |||||
| name = scope; | |||||
| var value = gen_resource_variable_ops.resource_gather( | |||||
| _handle, indices, dtype: _dtype, name: name); | |||||
| return array_ops.identity(value); | |||||
| }); | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}"; | return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}"; | ||||