using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.Graphs;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
    // Wraps a single raw TensorFlow operation (described by a NodeDef) in a
    // Keras layer, so arbitrary graph ops can participate in a functional model.
    public class TensorFlowOpLayer : Layer
    {
        TensorFlowOpLayerArgs args;

        // Constant inputs captured for the wrapped op, keyed by input index.
        Dictionary<int, NDArray> constants => args.Constants;
        NodeDef node_def => args.NodeDef;
        static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_";

        public string OpType => node_def.Op;

        public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
            : base(new LayerArgs
            {
                Name = TF_OP_LAYER_NAME_PREFIX + args.Name,
                Trainable = args.Trainable,
                DType = args.DType,
                Autocast = false
            })
        {
            this.args = args;
            // The layer owns no weights, so it is considered built immediately.
            built = true;
        }

        protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
        {
            // Eager execution goes through the AutoGraph-traced path;
            // graph mode builds the op directly.
            if (tf.Context.executing_eagerly())
                return _defun_call(inputs);
            return MakeOp(inputs);
        }

        [AutoGraph]
        Tensors _defun_call(Tensors inputs)
            => MakeOp(inputs);

        // Re-creates the wrapped op, from its stored NodeDef, inside the graph
        // that owns the inputs.
        Tensors MakeOp(Tensors inputs)
        {
            var graph = inputs.graph;
            graph.as_default();

            // Materialize captured constants as constant tensors at their
            // original input positions.
            foreach (var (index, constant) in enumerate(constants))
            {
                var value = constant_op.constant(constant, name: node_def.Input[index]);
                inputs.Insert(index, value);
            }

            var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]);
            var op = graph._create_op_from_tf_operation(c_op);
            op._control_flow_post_processing();

            // Record the gradient because custom-made ops don't go through the
            // code-gen'd eager call path.
            var op_type = op.node_def.Op;
            tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs);

            graph.Exit();
            return op.outputs;
        }

        public Layer GetOpLayer(TensorFlowOpLayerArgs args)
            => new TensorFlowOpLayer(args);
    }
}
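
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the library source above).
// It shows how a single raw op ("AddV2") might be wrapped as a Keras layer.
// Assumptions, not confirmed by the file above: that TensorFlowOpLayerArgs
// exposes settable Name/NodeDef/Constants properties, that Layer provides an
// Apply(...) entry point mirroring the Call(...) override, and that the
// NodeDef/AttrValue protobuf fields are filled in as shown.
// ---------------------------------------------------------------------------
static class TensorFlowOpLayerUsageSketch
{
    // Wraps an "AddV2" op as a layer and applies it to two input tensors.
    public static Tensors AddAsOpLayer(Tensor x, Tensor y)
    {
        var nodeDef = new NodeDef { Name = "add_1", Op = "AddV2" };
        nodeDef.Input.Add("x");
        nodeDef.Input.Add("y");
        // Most ops need their attrs (here the dtype attr "T") set before the
        // C op can be created from the NodeDef.
        nodeDef.Attr["T"] = new AttrValue { Type = DataType.DtFloat };

        var layer = new Tensorflow.Keras.Layers.TensorFlowOpLayer(new TensorFlowOpLayerArgs
        {
            Name = "add_1",
            NodeDef = nodeDef,
            Constants = new Dictionary<int, NDArray>() // no captured constants
        });

        // Apply dispatches to Call, which rebuilds the op from the NodeDef in
        // the graph (or traced function) that owns the inputs.
        return layer.Apply(new Tensors(x, y));
    }
}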