diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs index 926e51a9..673a7164 100644 --- a/LLama.Unittest/GrammarParserTest.cs +++ b/LLama.Unittest/GrammarParserTest.cs @@ -68,7 +68,7 @@ namespace LLama.Unittest new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), - new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs index 3e038992..a59d6e1a 100644 --- a/LLama/Grammar/GrammarParser.cs +++ b/LLama/Grammar/GrammarParser.cs @@ -42,7 +42,7 @@ namespace LLama.Grammar private uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) { uint nextId = (uint)state.SymbolIds.Count; - string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray()); + string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray()); if (state.SymbolIds.TryGetValue(key, out uint existingId)) { @@ -344,7 +344,7 @@ namespace LLama.Grammar ReadOnlySpan nameEnd = ParseName(src); ReadOnlySpan pos = ParseSpace(nameEnd, false); int nameLen = src.Length - nameEnd.Length; - uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), nameLen); + uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0); string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray()); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) diff --git a/LLama/Grammar/ParseState.cs b/LLama/Grammar/ParseState.cs index ff42f527..f85fa032 100644 --- a/LLama/Grammar/ParseState.cs +++ b/LLama/Grammar/ParseState.cs @@ -12,7 +12,7 @@ namespace LLama.Grammar /// public class ParseState { - public Dictionary SymbolIds { get; } = new Dictionary(); + public SortedDictionary SymbolIds { get; } = new SortedDictionary(); public List> Rules { get; } = new List>(); public IEnumerable> CRules()