Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,5 @@ site/
/Llama.Mobile/Resources/Raw
/nuget_pack.bat
/temp
LLama.Web/wwwroot/lib/
LLama.Web/Models/
147 changes: 122 additions & 25 deletions LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ namespace LLama.Examples.Examples
// It uses the interactive executor to inference.
public class MtmdInteractiveModeExecute
{
/// <summary>
/// Markers probed from the model's chat template: the literal text that ends an
/// assistant turn, and the literal text between an assistant turn and the next
/// user turn. Either may be null when the template could not be probed.
/// </summary>
private sealed record TemplateMarkers(string? AssistantEndMarker, string? AssistantToUserMarker);

public static async Task Run()
{
string multiModalProj = UserSettings.GetMMProjPath();
string modelPath = UserSettings.GetModelPath();
const int maxTokens = 4096;

string? prompt = await File.ReadAllTextAsync("Assets/chat-with-bob.json");

var parameters = new ModelParams(modelPath);

var mtmdParameters = MtmdContextParams.Default();
Expand All @@ -40,8 +40,8 @@ public static async Task Run()

var ex = new InteractiveExecutor(context, clipModel);
var chatHistory = new ChatHistory();
var isFirstTurn = true;

var templateMarkers = ResolveTemplateMarkers(model);
var antiPrompts = GetEffectiveAntiPrompts(templateMarkers);
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize );
Console.WriteLine("Model: {0}", modelPath);
Expand All @@ -54,6 +54,7 @@ public static async Task Run()
Console.WriteLine("Video inputs are not supported (format would be {{c:/video.mp4}}).");
Console.WriteLine("Commands: /exit (return to main menu) | /clear (reset the chat history and KV cache).");
Console.WriteLine("Press Ctrl+c to return to main menu.");
Console.Write("User: ");

void ResetConversation()
{
Expand All @@ -64,7 +65,7 @@ void ResetConversation()
clipModel.ClearMedia();
chatHistory.Messages.Clear();
ex = new InteractiveExecutor(context, clipModel);
Console.WriteLine("User:");
Console.Write("User: ");
}

var inferenceParams = new InferenceParams
Expand All @@ -74,34 +75,34 @@ void ResetConversation()
Temperature = 0.1f
},

AntiPrompts = new List<string> { "User:" },
MaxTokens = maxTokens
AntiPrompts = antiPrompts,
MaxTokens = maxTokens,
DecodeSpecialTokens = ShouldDecodeSpecialTokens(antiPrompts)

};

do
{
if (!isFirstTurn)
{
Console.ForegroundColor = ConsoleColor.Green;
prompt = Console.ReadLine();
Console.WriteLine();
Console.ForegroundColor = ConsoleColor.Green;
var prompt = Console.ReadLine();
Console.WriteLine();

if (prompt == null || prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;
if (prompt == null || prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
break;

if (prompt.Equals("/clear", StringComparison.OrdinalIgnoreCase))
{
ResetConversation();
continue;
}
if (prompt.Equals("/clear", StringComparison.OrdinalIgnoreCase))
{
ResetConversation();
continue;
}
else

if (string.IsNullOrWhiteSpace(prompt))
{
isFirstTurn = false;
Console.Write("User: ");
continue;
}

var userPrompt = prompt ?? string.Empty;
var userPrompt = prompt;

// Evaluate if we have media
//
Expand Down Expand Up @@ -171,7 +172,7 @@ void ResetConversation()
audioList.Add(mediaPath);
}

var embed = clipModel.LoadMedia(mediaPath);
var embed = clipModel.LoadMediaStandalone(mediaPath);
embeds.Add(embed);
}
}
Expand All @@ -183,7 +184,8 @@ void ResetConversation()
Console.Write($"{exception.Message}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Please try again.");
clipModel.ClearMedia();
foreach (var embed in embeds)
embed.Dispose();
break;
}

Expand Down Expand Up @@ -236,8 +238,9 @@ void ResetConversation()
Console.Write(text);
responseBuilder.Append(text);
}
Console.Write(" ");
Console.WriteLine();
chatHistory.AddMessage(AuthorRole.Assistant, responseBuilder.ToString());
Console.Write("User: ");

}
while(true);
Expand All @@ -259,6 +262,45 @@ private static string BuildChatDelta(LLamaWeights model, ChatHistory history, st
return delta;
}

/// <summary>
/// Builds the anti-prompt list from the resolved template markers, falling back
/// to the conventional literal "User:" when the template yielded no markers.
/// </summary>
private static List<string> GetEffectiveAntiPrompts(TemplateMarkers templateMarkers)
{
    var prompts = new List<string>();
    AddMarker(prompts, templateMarkers.AssistantEndMarker);
    AddMarker(prompts, templateMarkers.AssistantToUserMarker);

    // No template-derived markers: fall back to the default prompt label.
    return prompts.Count > 0 ? prompts : new List<string> { "User:" };
}

/// <summary>
/// Appends <paramref name="marker"/> and its trimmed form to
/// <paramref name="values"/>, skipping blanks and duplicates.
/// </summary>
private static void AddMarker(List<string> values, string? marker)
{
    // Blank or missing markers contribute nothing.
    if (string.IsNullOrWhiteSpace(marker))
        return;

    // Adds a candidate only when it is non-blank and not already present.
    static void Append(List<string> target, string candidate)
    {
        if (!string.IsNullOrWhiteSpace(candidate) && !target.Contains(candidate))
            target.Add(candidate);
    }

    Append(values, marker);
    Append(values, marker.Trim());
}

/// <summary>
/// Returns true when any anti-prompt looks like a special token (contains an
/// angle bracket, e.g. "&lt;|im_end|&gt;"), in which case the executor must
/// decode special tokens for the anti-prompt to ever match the output stream.
/// </summary>
private static bool ShouldDecodeSpecialTokens(IReadOnlyList<string> antiPrompts)
{
    for (var i = 0; i < antiPrompts.Count; i++)
    {
        var candidate = antiPrompts[i];
        if (string.IsNullOrWhiteSpace(candidate))
            continue;

        var looksLikeSpecialToken =
            candidate.Contains('<', StringComparison.Ordinal) ||
            candidate.Contains('>', StringComparison.Ordinal);
        if (looksLikeSpecialToken)
            return true;
    }

    return false;
}

private static string FormatChatHistory(LLamaWeights model, ChatHistory history, bool addAssistant)
{
var template = new LLamaTemplate(model.NativeHandle)
Expand All @@ -271,5 +313,60 @@ private static string FormatChatHistory(LLamaWeights model, ChatHistory history,

return LLamaTemplate.Encoding.GetString(template.Apply());
}

/// <summary>
/// Probes the model's chat template with sentinel strings to discover the
/// literal text emitted after an assistant message (end marker) and the literal
/// text separating an assistant message from the following user message.
/// Returns a <see cref="TemplateMarkers"/> with nulls when probing fails.
/// </summary>
private static TemplateMarkers ResolveTemplateMarkers(LLamaWeights model)
{
    const string firstUserProbe = "__LLAMA_USER_A__";
    const string assistantProbe = "__LLAMA_ASSISTANT_A__";
    const string secondUserProbe = "__LLAMA_USER_B__";

    try
    {
        // Render a minimal user+assistant exchange; everything after the
        // assistant sentinel is the assistant end marker.
        var shortTemplate = new LLamaTemplate(model.NativeHandle)
        {
            AddAssistant = false
        };
        shortTemplate.Add("user", firstUserProbe);
        shortTemplate.Add("assistant", assistantProbe);

        var shortRendered = LLamaTemplate.Encoding.GetString(shortTemplate.Apply());
        var probeAt = shortRendered.IndexOf(assistantProbe, StringComparison.Ordinal);
        if (probeAt < 0)
            return new TemplateMarkers(null, null);

        var endMarker = shortRendered.Substring(probeAt + assistantProbe.Length);

        // Render a longer exchange to capture the text between the assistant
        // turn and the next user turn.
        var longTemplate = new LLamaTemplate(model.NativeHandle)
        {
            AddAssistant = false
        };
        longTemplate.Add("user", firstUserProbe);
        longTemplate.Add("assistant", assistantProbe);
        longTemplate.Add("user", secondUserProbe);

        var longRendered = LLamaTemplate.Encoding.GetString(longTemplate.Apply());
        var assistantAt = longRendered.IndexOf(assistantProbe, StringComparison.Ordinal);
        var nextUserAt = longRendered.IndexOf(secondUserProbe, StringComparison.Ordinal);
        if (assistantAt < 0 || nextUserAt <= assistantAt)
            return new TemplateMarkers(NormalizeMarker(endMarker), null);

        var betweenStart = assistantAt + assistantProbe.Length;
        var betweenMarker = longRendered.Substring(betweenStart, nextUserAt - betweenStart);

        return new TemplateMarkers(
            NormalizeMarker(endMarker),
            NormalizeMarker(betweenMarker));
    }
    catch
    {
        // Best effort: any template failure degrades to "no markers".
        return new TemplateMarkers(null, null);
    }
}

/// <summary>
/// Maps null or whitespace-only markers to null; any other marker is returned
/// unchanged (whitespace is preserved, not trimmed).
/// </summary>
private static string? NormalizeMarker(string? marker)
{
    if (string.IsNullOrWhiteSpace(marker))
        return null;

    return marker;
}
}
}
Loading
Loading