From ab8e72d8f77cb052a7f808b2cf36c3d3ae49409d Mon Sep 17 00:00:00 2001
From: Bradley Grainger
Date: Sat, 9 Nov 2024 09:46:43 -0800
Subject: [PATCH] Defer the allocation of matchValue. (#61)

matchValue isn't needed on the fast path, so don't allocate it until it's
actually needed.

Additionally, use a Span-based conversion to UTF-8 bytes to avoid the
creation of a temporary char array.
---
Review notes (ignored by git am; not part of the commit message):
- Indentation and blank-line placement of the context lines were reconstructed
  from a whitespace-collapsed copy of this patch; verify them against
  CoreBPE.cs at blob c10bd1c before applying. The post-image blob id on the
  "index" line still refers to the originally submitted revision.
- NOTE(review): in EncodeNative the matches are enumerated over
  textSpan[start..specialStart], so match.Index looks relative to that slice,
  yet it is used to slice textSpan itself (fastKey and piece). Confirm whether
  an offset of `start` is needed.

 src/libs/Tiktoken.Core/CoreBPE.cs | 40 +++++++++++++++++++++++++++++++++++++---
 1 file changed, 37 insertions(+), 3 deletions(-)

diff --git a/src/libs/Tiktoken.Core/CoreBPE.cs b/src/libs/Tiktoken.Core/CoreBPE.cs
index c10bd1c..88de98c 100644
--- a/src/libs/Tiktoken.Core/CoreBPE.cs
+++ b/src/libs/Tiktoken.Core/CoreBPE.cs
@@ -1,4 +1,5 @@
 using System.Collections.Concurrent;
+using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.RegularExpressions;
 using Tiktoken.Core;
@@ -86,12 +87,12 @@ public int CountTokensNative(string text)
         var tokens = 0;
 #if NET7_0_OR_GREATER
         var textSpan = text.AsSpan();
+        Span<byte> pieceBytes = stackalloc byte[128];
 #endif
 
 #if NET7_0_OR_GREATER
         foreach (var match in Regex.EnumerateMatches(textSpan))
         {
-            var matchValue = textSpan.Slice(match.Index, match.Length).ToArray();
             var fastKey = new string(textSpan.Slice(match.Index, match.Length));
 #else
         foreach (Match match in Regex.Matches(text))
@@ -110,7 +111,11 @@ public int CountTokensNative(string text)
                 continue;
             }
 
+#if NET7_0_OR_GREATER
+            var piece = GetUtf8Bytes(textSpan.Slice(match.Index, match.Length), pieceBytes);
+#else
             var piece = System.Text.Encoding.UTF8.GetBytes(matchValue);
+#endif
             if (Encoder.ContainsKey(piece))
             {
                 tokens++;
@@ -148,6 +153,7 @@ public IReadOnlyCollection<int> EncodeNative(
         var tokens = new List<int>();
 #if NET7_0_OR_GREATER
         var textSpan = text.AsSpan();
+        Span<byte> pieceBytes = stackalloc byte[128];
 #endif
 
         var specialTokens = new List<(int Index, int Length)>(capacity: 32);
@@ -181,7 +187,6 @@ public IReadOnlyCollection<int> EncodeNative(
 #if NET7_0_OR_GREATER
             foreach (var match in Regex.EnumerateMatches(textSpan[start..specialStart]))
             {
-                var matchValue = textSpan.Slice(match.Index, match.Length).ToArray();
                 var fastKey = new string(textSpan.Slice(match.Index, match.Length));
 #else
             foreach (Match match in Regex.Matches(text[start..specialStart]))
@@ -199,8 +204,12 @@ public IReadOnlyCollection<int> EncodeNative(
                     tokens.AddRange(fastTokens);
                     continue;
                 }
-                
+
+#if NET7_0_OR_GREATER
+                var piece = GetUtf8Bytes(textSpan.Slice(match.Index, match.Length), pieceBytes);
+#else
                 var piece = System.Text.Encoding.UTF8.GetBytes(matchValue);
+#endif
                 if (Encoder.TryGetValue(piece, out var token))
                 {
                     tokens.Add(token);
@@ -544,4 +553,29 @@ public byte[] DecodeNative(IReadOnlyCollection<int> tokens)
         }
         return ret.ToArray();
     }
+
+#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER
+    /// <summary>
+    /// Encodes <paramref name="text"/> as UTF-8 and returns the result in a new array.
+    /// </summary>
+    /// <param name="text">The UTF-16 text to encode.</param>
+    /// <param name="scratch">Temporary buffer that is used when it is large enough to hold the encoded text.</param>
+    /// <returns>A new array containing exactly the UTF-8 bytes of <paramref name="text"/>.</returns>
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    private static byte[] GetUtf8Bytes(ReadOnlySpan<char> text, Span<byte> scratch)
+    {
+        // Each UTF-16 char encodes to at most 3 UTF-8 bytes. Dividing the buffer length
+        // (rather than multiplying the text length) keeps the check free of int overflow.
+        if (text.Length <= scratch.Length / 3)
+        {
+            return scratch[..System.Text.Encoding.UTF8.GetBytes(text, scratch)].ToArray();
+        }
+
+        // Too long for the scratch buffer: size the result exactly and encode straight
+        // into it, so no temporary char array is created.
+        var bytes = new byte[System.Text.Encoding.UTF8.GetByteCount(text)];
+        System.Text.Encoding.UTF8.GetBytes(text, bytes);
+        return bytes;
+    }
+#endif
 }
\ No newline at end of file