From 2295a04ecd3af4b73383e4f17dec29b6e902ab3b Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sun, 7 May 2023 22:49:57 +0800 Subject: [PATCH] fix: revise wrong behaviors of op code generator. --- Tensorflow.CodeGen/FunctionGenerator.cs | 284 +++++++++++++------ Tensorflow.CodeGen/GenOpsWriter.cs | 4 +- Tensorflow.CodeGen/OpClassifier.cs | 30 +- Tensorflow.CodeGen/Program.cs | 2 + Tensorflow.CodeGen/Tensorflow.CodeGen.csproj | 5 +- Tensorflow.CodeGen/Utils.cs | 15 +- 6 files changed, 242 insertions(+), 98 deletions(-) diff --git a/Tensorflow.CodeGen/FunctionGenerator.cs b/Tensorflow.CodeGen/FunctionGenerator.cs index d4520307..b3b695c5 100644 --- a/Tensorflow.CodeGen/FunctionGenerator.cs +++ b/Tensorflow.CodeGen/FunctionGenerator.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Linq.Expressions; using System.Reflection.Metadata.Ecma335; using System.Text; using System.Threading.Tasks; @@ -16,17 +17,17 @@ namespace Tensorflow.CodeGen // TODO: add descriptions sb.Append("public static "); int outputArgsCount = op.OutputArg.Count; - if (outputArgsCount > 1) + if (outputArgsCount == 0) { - sb.Append("Tensor[] "); + sb.Append("Operation "); } - else if (outputArgsCount == 1) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) { sb.Append("Tensor "); } else { - sb.Append("Operation "); + sb.Append("Tensor[] "); } string funcName = Utils.ConvertToUnderscore(op.Name); var token = SyntaxFactory.ParseToken(funcName); @@ -42,6 +43,17 @@ namespace Tensorflow.CodeGen // begin to write main body sb.AppendLine("var _ctx = tf.Context;"); + + var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); + // deal with dynamic default values. + foreach(var (name, expr) in dynamicDefaultValues) + { + sb.AppendLine($"if({name} is null)"); + sb.AppendLine("{"); + sb.AppendLine($"{name} = {expr};"); + sb.AppendLine("}"); + } + sb.AppendLine("if(_ctx.executing_eagerly()){"); if(HasRefArgs(op)) @@ -58,7 +70,7 @@ namespace Tensorflow.CodeGen { sb.AppendLine("return null;"); } - else if (outputArgsCount == 1) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) { sb.AppendLine("return _fast_path_result[0];"); } @@ -82,6 +94,17 @@ namespace Tensorflow.CodeGen sb.AppendLine("}"); // if + foreach(var (name, type, value) in attrValueDic.Where(x => x.Item2 == "string")) + { + if(value != "NOVALUE") + { + sb.AppendLine($"if({name} is null)"); + sb.AppendLine("{"); + sb.AppendLine($"{name} = {value};"); + sb.AppendLine("}"); + } + } + // begin to use op helper. AppendOpHelperCall(op, sb); sb.AppendLine("var _result = _op.outputs;"); @@ -126,7 +149,7 @@ namespace Tensorflow.CodeGen { sb.AppendLine("return _op;"); } - else if (outputArgsCount == 1) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) { sb.AppendLine("return _result[0];"); } @@ -160,8 +183,8 @@ namespace Tensorflow.CodeGen sb.Append($"Tensor {argName}, "); } } - var attrValueDic = GetAttrsDefaultValue(op); - foreach (var (key, (typeStr, value)) in attrValueDic) + var attrValueDic = GetAttrsDefaultValue(op, out var dynamicDefaultValues); + foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 == "NOVALUE")) { var token = SyntaxFactory.ParseToken(key); string realKey = key; @@ -169,21 +192,25 @@ namespace Tensorflow.CodeGen { realKey += "_"; } - if (value != "NOVALUE") - { - sb.Append($"{typeStr} {realKey} = {value}, "); - } - else + sb.Append($"{typeStr} {realKey}, "); + } + foreach (var (key, typeStr, value) in attrValueDic.Where(x => x.Item3 != "NOVALUE")) + { + var token = SyntaxFactory.ParseToken(key); + string realKey = key; + if (token.IsKeyword()) { - sb.Append($"{typeStr} {realKey}, "); + realKey += "_"; } + sb.Append($"{typeStr} {realKey} = {value}, "); } sb.Append($"string? name = null"); } public void AppendFastPathExecute(OpDef op, StringBuilder sb) { - sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name, "); + sb.Append($"var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, \"{op.Name}\", name)"); + sb.Append("{ args = new object[]{ "); foreach (var arg in op.InputArg) { string attrArgName = arg.Name; @@ -193,16 +220,23 @@ namespace Tensorflow.CodeGen } sb.Append($"{attrArgName}, "); } - var attrValueDic = GetAttrsDefaultValue(op); - foreach (var (key, _) in attrValueDic) + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') { - sb.Append($"\"{key}\", {key}, "); + sb.Remove(sb.Length - 2, 2); + } + + sb.Append("}, attrs = new Dictionary(){ "); + var attrValueDic = GetAttrsDefaultValue(op, out var _); + foreach (var (key, _, _) in attrValueDic) + { + sb.Append($"[\"{key}\"] = {key}, "); } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') { sb.Remove(sb.Length - 2, 2); } - sb.Append("));\n"); + sb.Append("}});\n"); } public void AppendEagerFallbackCall(OpDef op, StringBuilder sb) @@ -218,8 +252,8 @@ namespace Tensorflow.CodeGen } sb.Append($"{inputArgRealName}, "); } - var attrValueDic = GetAttrsDefaultValue(op); - foreach (var (key, _) in attrValueDic) + var attrValueDic = GetAttrsDefaultValue(op, out var _); + foreach (var (key, _, _) in attrValueDic) { string keyRealName = key; if (SyntaxFactory.ParseToken(keyRealName).IsKeyword()) @@ -233,11 +267,19 @@ namespace Tensorflow.CodeGen public void AppendEagerFallbackDefinition(OpDef op, StringBuilder sb) { - sb.Append("public static Tensor"); + sb.Append("public static "); int outputArgsCount = op.OutputArg.Count; - if (outputArgsCount > 1) + if (outputArgsCount == 0) + { + sb.Append("Operation "); + } + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) + { + sb.Append("Tensor "); + } + else { - sb.Append("[]"); + sb.Append("Tensor[] "); } string opName = op.Name; string funcName = Utils.ConvertToUnderscore(op.Name); @@ -254,24 +296,47 @@ namespace Tensorflow.CodeGen return; } - sb.Append("Tensor[] _inputs_flat = new Tensor[]{"); - foreach (var arg in op.InputArg) + if(op.InputArg.Any(x => !string.IsNullOrEmpty(x.NumberAttr))) { - string realArgName = arg.Name; - if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + sb.AppendLine("List _inputs_flat_list = new();"); + foreach (var arg in op.InputArg) { - realArgName = $"{realArgName}_"; + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + if (string.IsNullOrEmpty(arg.NumberAttr)) + { + sb.AppendLine($"_inputs_flat_list.Add({realArgName});"); + } + else + { + sb.AppendLine($"_inputs_flat_list.AddRange({realArgName});"); + } } - sb.Append($"{realArgName}, "); + sb.AppendLine($"var _inputs_flat = _inputs_flat_list.ToArray();"); } - if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + else { - sb.Remove(sb.Length - 2, 2); + sb.Append("Tensor[] _inputs_flat = new Tensor[]{"); + foreach (var arg in op.InputArg) + { + string realArgName = arg.Name; + if (SyntaxFactory.ParseToken(realArgName).IsKeyword()) + { + realArgName = $"{realArgName}_"; + } + sb.Append($"{realArgName}, "); + } + if (sb[sb.Length - 1] == ' ' && sb[sb.Length - 2] == ',') + { + sb.Remove(sb.Length - 2, 2); + } + sb.Append("};\n"); } - sb.Append("};\n"); sb.Append("object[] _attrs = new object[]{"); - var attrValueDic = GetAttrsDefaultValue(op); foreach (var attr in op.Attr) { if (attr.Type == "type") @@ -293,27 +358,15 @@ namespace Tensorflow.CodeGen } if (!found) { - if (attr.Name.StartsWith("T") && attr.Name.Length > 1) - { - string paramName = attr.Name.Substring(1); - if (SyntaxFactory.ParseToken(paramName).IsKeyword()) - { - paramName = $"{paramName}_"; - } - sb.Append($"\"{attr.Name}\", {paramName}.dtype, "); - } - else + string attrRealName = attr.Name; + if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) { - string attrRealName = attr.Name; - if (SyntaxFactory.ParseToken(attrRealName).IsKeyword()) - { - attrRealName = $"{attrRealName}_"; - } - sb.Append($"\"{attr.Name}\", {attrRealName}, "); + attrRealName = $"{attrRealName}_"; } + sb.Append($"\"{attr.Name}\", {attrRealName}, "); } } - else if(attr.Type == "int" && (op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name))) + else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name)) { bool found = false; foreach (var arg in op.InputArg) @@ -355,7 +408,7 @@ namespace Tensorflow.CodeGen { sb.AppendLine("return null;"); } - else if (outputArgsCount == 1) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) { sb.AppendLine("return _result[0];"); } @@ -386,8 +439,8 @@ namespace Tensorflow.CodeGen sb.Append($"Tensor {argName}, "); } } - var attrValueDic = GetAttrsDefaultValue(op); - foreach (var (key, (typeStr, _)) in attrValueDic) + var attrValueDic = GetAttrsDefaultValue(op, out var _); + foreach (var (key, typeStr, _) in attrValueDic) { var token = SyntaxFactory.ParseToken(key); string realKey = key; @@ -412,18 +465,19 @@ namespace Tensorflow.CodeGen } sb.AppendLine($"keywords[\"{arg.Name}\"] = {realArgName};"); } - var attrValueDic = GetAttrsDefaultValue(op); - foreach (var (key, _) in attrValueDic) + var attrValueDic = GetAttrsDefaultValue(op, out var _); + foreach (var (key, _, _) in attrValueDic) { - sb.Append($"keywords[\"{key}\"] = {key};"); + sb.AppendLine($"keywords[\"{key}\"] = {key};"); } sb.AppendLine($"var _op = tf.OpDefLib._apply_op_helper(\"{op.Name}\", name, keywords);"); } - // key, (type string, default value) - public Dictionary GetAttrsDefaultValue(OpDef op) + // name, type string, default value + public List<(string, string, string)> GetAttrsDefaultValue(OpDef op, out Dictionary dynamicDefaultValues) { - Dictionary dic = new(); + dynamicDefaultValues = new(); + List<(string, string, string)> res = new(); foreach (var attr in op.Attr) { if (attr.Type == "type") @@ -435,111 +489,177 @@ namespace Tensorflow.CodeGen { string name = Enum.GetName(typeof(TF_DataType), attr.DefaultValue.Type.as_tf_dtype()); string enumPath = typeof(TF_DataType).Name + "." + name; - dic[attr.Name] = ("TF_DataType", enumPath); + res.Add((attr.Name, "TF_DataType", enumPath)); } else { - dic[attr.Name] = ("TF_DataType", "NOVALUE"); + res.Add((attr.Name, "TF_DataType", "NOVALUE")); } } } else if (attr.Type == "int") { - if(op.InputArg.Any(x => x.NumberAttr == attr.Name) || op.OutputArg.Any(x => x.NumberAttr == attr.Name)) + if(op.InputArg.Any(x => x.NumberAttr == attr.Name)) { continue; } if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.I) { - dic[attr.Name] = ("int", attr.DefaultValue.I.ToString()); + res.Add((attr.Name, "int", attr.DefaultValue.I.ToString())); } else { - dic[attr.Name] = ("int", "0"); + res.Add((attr.Name, "int", "0")); } } else if (attr.Type == "float") { if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.F) { - dic[attr.Name] = ("float", attr.DefaultValue.F.ToString() + "f"); + res.Add((attr.Name, "float", attr.DefaultValue.F.ToString() + "f")); } else { - dic[attr.Name] = ("float", "NOVALUE"); + res.Add((attr.Name, "float", "NOVALUE")); } } else if (attr.Type == "string") { if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) { - dic[attr.Name] = ("string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\""); + res.Add((attr.Name, "string", $"\"{attr.DefaultValue.S.ToStringUtf8()}\"")); } else { - dic[attr.Name] = ("string", "NOVALUE"); + res.Add((attr.Name, "string", "NOVALUE")); } } else if (attr.Type == "bool") { if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.B) { - dic[attr.Name] = ("bool", attr.DefaultValue.B.ToString().ToLower()); + res.Add((attr.Name, "bool", attr.DefaultValue.B.ToString().ToLower())); } else { - dic[attr.Name] = ("bool", "NOVALUE"); + res.Add((attr.Name, "bool", "NOVALUE")); } } else if (attr.Type == "shape") { if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Shape) { - dic[attr.Name] = ("Shape", $"null"); + if (attr.DefaultValue.Shape.UnknownRank) + { + res.Add((attr.Name, "Shape", $"null")); + } + else + { + Shape shape = new Shape(attr.DefaultValue.Shape); + string expression = $"new Shape({string.Join(", ", shape.dims)})"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "Shape", $"null")); + } } else { - dic[attr.Name] = ("Shape", "NOVALUE"); + res.Add((attr.Name, "Shape", "NOVALUE")); } } else if (attr.Type == "list(type)") { - dic[attr.Name] = ("TF_DataType[]", "NOVALUE"); + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.Type) + { + values.Add(value.as_tf_dtype()); + } + string expression = "new TF_DataType[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "TF_DataType[]", $"null")); + } + else + { + res.Add((attr.Name, "TF_DataType[]", "NOVALUE")); + } } else if (attr.Type == "list(shape)") { - dic[attr.Name] = ("Shape[]", "NOVALUE"); + res.Add((attr.Name, "Shape[]", "NOVALUE")); } else if (attr.Type == "list(string)") { - dic[attr.Name] = ("string[]", "NOVALUE"); + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.S) + { + values.Add(value.ToStringUtf8()); + } + string expression = "new string[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "string[]", $"null")); + } + else + { + res.Add((attr.Name, "string[]", "NOVALUE")); + } } else if (attr.Type == "list(int)") { - dic[attr.Name] = ("int[]", "NOVALUE"); + if(attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List values = new(); + foreach(var value in attr.DefaultValue.List.I) + { + values.Add((int)value); + } + string expression = "new int[]{" + $"{string.Join(", ", values)}" +"}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "int[]", $"null")); + } + else + { + res.Add((attr.Name, "int[]", "NOVALUE")); + } } else if (attr.Type == "list(float)") { - dic[attr.Name] = ("float[]", "NOVALUE"); + if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List) + { + List values = new(); + foreach (var value in attr.DefaultValue.List.F) + { + values.Add(value); + } + string expression = "new float[]{" + $"{string.Join(", ", values)}" + "}"; + dynamicDefaultValues[attr.Name] = expression; + res.Add((attr.Name, "float[]", $"null")); + } + else + { + res.Add((attr.Name, "float[]", "NOVALUE")); + } } else if (attr.Type == "func") { - dic[attr.Name] = ("Func", "NOVALUE"); + res.Add((attr.Name, "Func", "NOVALUE")); } else if (attr.Type == "list(func)") { - dic[attr.Name] = ("Func[]", "NOVALUE"); + res.Add((attr.Name, "Func[]", "NOVALUE")); } else if (attr.Type == "tensor") { - dic[attr.Name] = ("TensorProto", "NOVALUE"); + res.Add((attr.Name, "TensorProto", "NOVALUE")); } else { throw new NotImplementedException(); } } - return dic; + return res; } private static bool HasRefArgs(OpDef op) diff --git a/Tensorflow.CodeGen/GenOpsWriter.cs b/Tensorflow.CodeGen/GenOpsWriter.cs index 83ca6e0b..2cd7bca5 100644 --- a/Tensorflow.CodeGen/GenOpsWriter.cs +++ b/Tensorflow.CodeGen/GenOpsWriter.cs @@ -21,7 +21,7 @@ namespace Tensorflow.CodeGen var opDefs = ReadAllOpDefs(opDefFilename); _opMap = opDefs.Op.ToDictionary( x => Tensorflow.CodeGen.Utils.ConvertToUnderscore(x.Name), x => x); - _opClassifier = new OpClassifier(pythonFilesDirectory); + _opClassifier = new OpClassifier(pythonFilesDirectory, opDefs.Op.Select(x => Utils.ConvertToUnderscore(x.Name))); } public void WriteAll() @@ -45,7 +45,7 @@ namespace Tensorflow.CodeGen sb.AppendLine(); // Write class name - sb.AppendLine($"internal static class {target}"); + sb.AppendLine($"public static class {target}"); sb.AppendLine("{"); foreach(var funcName in set) diff --git a/Tensorflow.CodeGen/OpClassifier.cs b/Tensorflow.CodeGen/OpClassifier.cs index 2ea2f35e..eaad3fec 100644 --- a/Tensorflow.CodeGen/OpClassifier.cs +++ b/Tensorflow.CodeGen/OpClassifier.cs @@ -10,27 +10,39 @@ namespace Tensorflow.CodeGen public class OpClassifier { private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$"; - private static readonly string _pythonFunctionPattern = @"def\s+(\w+)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*\w+\s*=None\s*\):"; + private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):"; private Dictionary> _opSet = new(); public Dictionary> OpSet => _opSet; - public OpClassifier(string pythonFileFolder) + public OpClassifier(string pythonFileFolder, IEnumerable funcNames) { DirectoryInfo directory = new DirectoryInfo(pythonFileFolder); + Dictionary fileContentMap = new(); foreach (FileInfo file in directory.GetFiles()) { if (Regex.IsMatch(file.Name, _filenamePattern)) { + Console.WriteLine(file.Name); string filenamePrefix = file.Name.Split('.')[0]; string content = File.ReadAllText(file.FullName); - var matches = Regex.Matches(content, _pythonFunctionPattern); - foreach(Match match in matches) + fileContentMap[filenamePrefix] = content; + } + } + + foreach(var funcName in funcNames) + { + Console.WriteLine(funcName); + string funcPattern = @$"^def\s+{funcName}\("; + string fallbackFuncPattern = @$"^def\s+{funcName}_eager_fallback\("; + foreach (var (target, content) in fileContentMap) + { + if(content.Contains($"def {funcName}") && content.Contains($"def {funcName}_eager_fallback")) + { + _opSet.SetDefault(target, new HashSet()).Add(funcName); + } + else if (content.Contains($"def _{funcName}") && content.Contains($"def _{funcName}_eager_fallback")) { - var funcName = match.Groups[1].Value; - if (!funcName.EndsWith("_eager_fallback")) - { - _opSet.SetDefault(filenamePrefix, new HashSet()).Add(funcName); - } + _opSet.SetDefault(target, new HashSet()).Add(funcName); } } } diff --git a/Tensorflow.CodeGen/Program.cs b/Tensorflow.CodeGen/Program.cs index d46dcdcb..a26031cb 100644 --- a/Tensorflow.CodeGen/Program.cs +++ b/Tensorflow.CodeGen/Program.cs @@ -5,6 +5,8 @@ using System.Text; using System.Xml.Linq; using Tensorflow.CodeGen; +//Console.WriteLine(Utils.ConvertToUnderscore("LRN")); + GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); diff --git a/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj b/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj index 61273d01..a052eb69 100644 --- a/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj +++ b/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj @@ -1,4 +1,4 @@ - + Exe @@ -9,10 +9,11 @@ + - + diff --git a/Tensorflow.CodeGen/Utils.cs b/Tensorflow.CodeGen/Utils.cs index 8cf21dee..608222e0 100644 --- a/Tensorflow.CodeGen/Utils.cs +++ b/Tensorflow.CodeGen/Utils.cs @@ -18,15 +18,24 @@ namespace Tensorflow.CodeGen StringBuilder result = new StringBuilder(); - int state = 0; // the previous char was not lowered. + int state = 1; // the previous char was not lowered. for (int i = 0; i < input.Length; i++) { char current = input[i]; // 首字母不需要添加下划线 - if (i != 0 && char.IsUpper(current)) + if (char.IsUpper(current)) { - if(state == 0) + if(i > 0) + { + char pre = input[i - 1]; + if (char.IsDigit(pre)) + { + result.Append(char.ToLower(current)); + continue; + } + } + if (state == 0) { result.Append("_"); state = 1;