Skip to content

Commit

Permalink
Defer the allocation of matchValue. (#61)
Browse files Browse the repository at this point in the history
matchValue isn't needed on the fast path, so don't allocate it until it's actually needed. Additionally, use a Span-based conversion to UTF-8 bytes to avoid the creation of a temporary char array.
  • Loading branch information
bgrainger authored Nov 9, 2024
1 parent a1b5ca5 commit ab8e72d
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions src/libs/Tiktoken.Core/CoreBPE.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.RegularExpressions;
using Tiktoken.Core;
Expand Down Expand Up @@ -86,12 +87,12 @@ public int CountTokensNative(string text)
var tokens = 0;
#if NET7_0_OR_GREATER
var textSpan = text.AsSpan();
Span<byte> pieceBytes = stackalloc byte[128];
#endif

#if NET7_0_OR_GREATER
foreach (var match in Regex.EnumerateMatches(textSpan))
{
var matchValue = textSpan.Slice(match.Index, match.Length).ToArray();
var fastKey = new string(textSpan.Slice(match.Index, match.Length));
#else
foreach (Match match in Regex.Matches(text))
Expand All @@ -110,7 +111,11 @@ public int CountTokensNative(string text)
continue;
}

#if NET7_0_OR_GREATER
var piece = GetUtf8Bytes(textSpan.Slice(match.Index, match.Length), pieceBytes);
#else
var piece = System.Text.Encoding.UTF8.GetBytes(matchValue);
#endif
if (Encoder.ContainsKey(piece))
{
tokens++;
Expand Down Expand Up @@ -148,6 +153,7 @@ public IReadOnlyCollection<int> EncodeNative(
var tokens = new List<int>();
#if NET7_0_OR_GREATER
var textSpan = text.AsSpan();
Span<byte> pieceBytes = stackalloc byte[128];
#endif

var specialTokens = new List<(int Index, int Length)>(capacity: 32);
Expand Down Expand Up @@ -181,7 +187,6 @@ public IReadOnlyCollection<int> EncodeNative(
#if NET7_0_OR_GREATER
foreach (var match in Regex.EnumerateMatches(textSpan[start..specialStart]))
{
var matchValue = textSpan.Slice(match.Index, match.Length).ToArray();
var fastKey = new string(textSpan.Slice(match.Index, match.Length));
#else
foreach (Match match in Regex.Matches(text[start..specialStart]))
Expand All @@ -199,8 +204,12 @@ public IReadOnlyCollection<int> EncodeNative(
tokens.AddRange(fastTokens);
continue;
}


#if NET7_0_OR_GREATER
var piece = GetUtf8Bytes(textSpan.Slice(match.Index, match.Length), pieceBytes);
#else
var piece = System.Text.Encoding.UTF8.GetBytes(matchValue);
#endif
if (Encoder.TryGetValue(piece, out var token))
{
tokens.Add(token);
Expand Down Expand Up @@ -544,4 +553,20 @@ public byte[] DecodeNative(IReadOnlyCollection<int> tokens)
}
return ret.ToArray();
}

#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER
/// <summary>
/// Encodes <paramref name="text"/> as UTF-8 and returns the bytes in a newly allocated array.
/// </summary>
/// <param name="text">The UTF-16 characters to encode.</param>
/// <param name="scratch">
/// Caller-provided temporary buffer used to encode short inputs without first computing
/// their encoded length. Its contents are overwritten when it is used.
/// </param>
/// <returns>A new array containing exactly the UTF-8 encoding of <paramref name="text"/>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static byte[] GetUtf8Bytes(ReadOnlySpan<char> text, Span<byte> scratch)
{
    // Each UTF-16 char can become at most 3 UTF-8 bytes, so the text is guaranteed to fit
    // when text.Length * 3 <= scratch.Length. The check divides instead of multiplying so
    // that a very large text.Length cannot overflow int and wrongly select this branch.
    if (text.Length <= scratch.Length / 3)
    {
        var written = System.Text.Encoding.UTF8.GetBytes(text, scratch);
        return scratch[..written].ToArray();
    }

    // Too long for the scratch buffer: size the result exactly and encode straight into it,
    // without copying the characters into a temporary char array first.
    var bytes = new byte[System.Text.Encoding.UTF8.GetByteCount(text)];
    System.Text.Encoding.UTF8.GetBytes(text, bytes);
    return bytes;
}
#endif
}

0 comments on commit ab8e72d

Please sign in to comment.