From 11416ebc3a941b73469d20023ba883569a17dc33 Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Wed, 18 Mar 2026 18:35:50 +1100 Subject: [PATCH 01/13] Call babble_trainer for inference --- .gitmodules | 3 + Baballonia.sln | 10 ++- src/Baballonia/Baballonia.csproj | 1 + src/Baballonia/Services/EyePipelineManager.cs | 33 ++++----- .../Inference/EyeProcessingPipeline.cs | 69 +++++++++++++++---- .../Services/ParameterSenderService.cs | 8 +-- src/Baballonia/Utils/Utils.cs | 2 +- .../SplitViewPane/CalibrationViewModel.cs | 8 +-- src/Baballonia/Views/HomePageView.axaml.cs | 6 +- src/VRCFaceTracking.Baballonia/BabbleOSC.cs | 12 ++-- src/babble_trainer | 1 + 11 files changed, 105 insertions(+), 48 deletions(-) create mode 160000 src/babble_trainer diff --git a/.gitmodules b/.gitmodules index 36dad3ba..2024a7ff 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "src/espflash"] path = src/espflash url = https://github.com/esp-rs/espflash.git +[submodule "src/babble_trainer"] + path = src/babble_trainer + url = https://github.com/acmdf/babble_trainer diff --git a/Baballonia.sln b/Baballonia.sln index e151a5ef..1ea01223 100644 --- a/Baballonia.sln +++ b/Baballonia.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 17 -VisualStudioVersion = 17.13.35919.96 +# Visual Studio Version 18 +VisualStudioVersion = 18.4.11605.240 stable MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Baballonia", "src\Baballonia\Baballonia.csproj", "{00505DCC-588E-4E46-8F91-2AF6A88CBC78}" EndProject @@ -47,6 +47,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Baballonia.CaptureBin.IO", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Baballonia.LibV4L2Capture", "src\Baballonia.LibV4L2Capture\Baballonia.LibV4L2Capture.csproj", "{02E9F6A2-A443-491D-93FF-6F002F3C494F}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "babble_trainer", "src\babble_trainer\bindings\bindings.csproj", "{692165EE-1668-066F-9D57-CC06413C6A55}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -126,6 +128,10 @@ Global {02E9F6A2-A443-491D-93FF-6F002F3C494F}.Debug|Any CPU.Build.0 = Debug|Any CPU {02E9F6A2-A443-491D-93FF-6F002F3C494F}.Release|Any CPU.ActiveCfg = Release|Any CPU {02E9F6A2-A443-491D-93FF-6F002F3C494F}.Release|Any CPU.Build.0 = Release|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Debug|Any CPU.Build.0 = Debug|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Release|Any CPU.ActiveCfg = Release|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Baballonia/Baballonia.csproj b/src/Baballonia/Baballonia.csproj index ea39b246..e26b3431 100644 --- a/src/Baballonia/Baballonia.csproj +++ b/src/Baballonia/Baballonia.csproj @@ -153,6 +153,7 @@ + diff --git a/src/Baballonia/Services/EyePipelineManager.cs b/src/Baballonia/Services/EyePipelineManager.cs index 8d3c9e37..f61de78c 100644 --- a/src/Baballonia/Services/EyePipelineManager.cs +++ b/src/Baballonia/Services/EyePipelineManager.cs @@ -54,29 +54,26 @@ public void InitializePipeline() public async Task LoadInferenceAsync() { - var inf = await Task.Run(CreateInference); - _pipeline.InferenceService = inf; + await Task.Run(CreateInference); } - private DefaultInferenceRunner CreateInference() + private void CreateInference() { - const string defaultEyeModelName = "eyeModel.onnx"; - var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel", defaultEyeModelName); - var eyeModelPath = Path.Combine(AppContext.BaseDirectory, eyeModelName); - - if (File.Exists(eyeModelPath)) return _inferenceFactory.Create(eyeModelPath); - _logger.LogError("{} Does not exists, Loading default...", eyeModelPath); - - eyeModelName = defaultEyeModelName; - eyeModelPath = Path.Combine(AppContext.BaseDirectory, eyeModelName); - - return _inferenceFactory.Create(eyeModelPath); - } + var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel"); + if (eyeModelName != null) + { + var eyeModelPath = Path.Combine(AppContext.BaseDirectory, eyeModelName); + var load_error = _pipeline.LoadInference(eyeModelPath); - public void LoadInference() - { - _pipeline.InferenceService = CreateInference(); + if (load_error != null) + { + _logger.LogError($"Inference error: {load_error}"); + } else + { + _logger.LogInformation($"Inference loaded from {eyeModelPath}"); + } + } } public void LoadFilter() diff --git a/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs b/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs index d6ad4728..24e2eeb6 100644 --- a/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs +++ b/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs @@ -1,6 +1,8 @@ using Baballonia.Services.events; using Baballonia.Services.Inference.Enums; using System; +using babble_model.Net.Sys; +using OpenCvSharp; namespace Baballonia.Services.Inference; @@ -17,6 +19,45 @@ public EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) public bool StabilizeEyes { get; set; } = true; + public unsafe string? LoadInference(string modelPath) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(modelPath); + fixed (byte* ptr = bytes) + { + var output = NativeMethods.loadModel(ptr); + + if (output.is_error) { + var errorMsg = new string((sbyte*)output.value.error_message); + + NativeMethods.freeModelOutputResult(output); + + return errorMsg; + } + + return null; + } + } + + unsafe float[]? RunInference(Mat collected) + { + var res = NativeMethods.infer(collected.ExtractChannel(0).DataPointer, collected.ExtractChannel(1).DataPointer); + + if (res.is_error) + { + string errorMsg = new String((sbyte*)res.value.error_message); + Console.WriteLine($"Inference error: {errorMsg}"); + return null; + } + + var output = res.value.model_output; + + float[] inferenceResult = { output.pitch_l, output.yaw_l, output.blink_l, output.eyebrow_l, output.eyewide_l, output.pitch_r, output.yaw_r, output.blink_r, output.eyebrow_r, output.eyewide_r }; + + NativeMethods.freeModelOutputResult(res); + + return inferenceResult; + } + public float[]? RunUpdate() { var frame = VideoSource?.GetFrame(ColorType.Gray8); @@ -39,13 +80,9 @@ public EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) if (collected == null) return null; - if (InferenceService == null) - return null; - - ImageConverter?.Convert(collected, InferenceService.GetInputTensor()); + var inferenceResult = RunInference(collected); - var inferenceResult = InferenceService?.Run(); - if(inferenceResult == null) + if (inferenceResult == null) return null; if (Filter != null) @@ -74,10 +111,14 @@ private bool ProcessExpressions(ref float[] arKitExpressions) var leftPitch = arKitExpressions[0] * mulY - mulY / 2; var leftYaw = arKitExpressions[1] * mulV - mulV / 2; var leftLid = 1 - arKitExpressions[2]; + var leftEyebrow = arKitExpressions[3]; + var leftEyewide = arKitExpressions[4]; - var rightPitch = arKitExpressions[3] * mulY - mulY / 2; - var rightYaw = arKitExpressions[4] * mulV - mulV / 2; - var rightLid = 1 - arKitExpressions[5]; + var rightPitch = arKitExpressions[5] * mulY - mulY / 2; + var rightYaw = arKitExpressions[6] * mulV - mulV / 2; + var rightLid = 1 - arKitExpressions[7]; + var rightEyebrow = arKitExpressions[8]; + var rightEyewide = arKitExpressions[9]; var eyeY = (leftPitch * leftLid + rightPitch * rightLid) / (leftLid + rightLid); @@ -101,9 +142,13 @@ private bool ProcessExpressions(ref float[] arKitExpressions) convertedExpressions[0] = rightEyeYawCorrected; // left pitch convertedExpressions[1] = eyeY; // left yaw convertedExpressions[2] = rightLid; // left lid - convertedExpressions[3] = leftEyeYawCorrected; // right pitch - convertedExpressions[4] = eyeY; // right yaw - convertedExpressions[5] = leftLid; // right lid + convertedExpressions[3] = leftEyebrow; // left eyebrow + convertedExpressions[4] = leftEyewide; // left eye wide + convertedExpressions[5] = leftEyeYawCorrected; // right pitch + convertedExpressions[6] = eyeY; // right yaw + convertedExpressions[7] = leftLid; // right lid + convertedExpressions[8] = leftEyebrow; // right eyebrow + convertedExpressions[9] = leftEyewide; // right eye wide arKitExpressions = convertedExpressions; diff --git a/src/Baballonia/Services/ParameterSenderService.cs b/src/Baballonia/Services/ParameterSenderService.cs index ff56038f..9dcc88c1 100644 --- a/src/Baballonia/Services/ParameterSenderService.cs +++ b/src/Baballonia/Services/ParameterSenderService.cs @@ -32,15 +32,15 @@ public class ParameterSenderService : BackgroundService { "LeftEyeX", "/LeftEyeX" }, { "LeftEyeY", "/LeftEyeY" }, { "LeftEyeLid", "/LeftEyeLid" }, - //{ "LeftEyeWiden", "/LeftEyeWiden" }, + { "LeftEyeBrow", "/LeftEyeBrow" }, + { "LeftEyeWiden", "/LeftEyeWiden" }, //{ "LeftEyeLower", "/LeftEyeLower" }, - //{ "LeftEyeBrow", "/LeftEyeBrow" }, { "RightEyeX", "/RightEyeX" }, { "RightEyeY", "/RightEyeY" }, { "RightEyeLid", "/RightEyeLid" }, - //{ "RightEyeWiden", "/RightEyeWiden" }, + { "RightEyeBrow", "/RightEyeBrow" }, + { "RightEyeWiden", "/RightEyeWiden" }, //{ "RightEyeLower", "/RightEyeLower" }, - //{ "RightEyeBrow", "/RightEyeBrow" }, }; public readonly Dictionary FaceExpressionMap = new() diff --git a/src/Baballonia/Utils/Utils.cs b/src/Baballonia/Utils/Utils.cs index 3e88f32f..3eca23f9 100644 --- a/src/Baballonia/Utils/Utils.cs +++ b/src/Baballonia/Utils/Utils.cs @@ -12,7 +12,7 @@ namespace Baballonia; public static class Utils { - public const int EyeRawExpressions = 6; + public const int EyeRawExpressions = 10; public const int FaceRawExpressions = 45; public const int FramesForEyeInference = 4; diff --git a/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs b/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs index b7394e49..62d2c370 100644 --- a/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs +++ b/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs @@ -42,12 +42,12 @@ public CalibrationViewModel(EyePipelineManager eyePipelineManager) [ new("LeftEyeLid"), new("RightEyeLid"), - //new ("LeftEyeWiden"), + new ("LeftEyeWiden"), //new ("LeftEyeLower"), - //new ("LeftEyeBrow"), - //new ("RightEyeWiden"), + new ("LeftEyeBrow"), + new ("RightEyeWiden"), //new ("RightEyeLower"), - //new ("RightEyeBrow"), + new ("RightEyeBrow"), ]; JawSettings = diff --git a/src/Baballonia/Views/HomePageView.axaml.cs b/src/Baballonia/Views/HomePageView.axaml.cs index 52d520c0..b67a80e0 100644 --- a/src/Baballonia/Views/HomePageView.axaml.cs +++ b/src/Baballonia/Views/HomePageView.axaml.cs @@ -19,6 +19,10 @@ public partial class HomePageView : ViewBase { Patterns = ["*.onnx"], }; + private static readonly FilePickerFileType RustTrainedModels = new("Rust-Trained Models") + { + Patterns = ["*.bin.gz"], + }; private bool _isLayoutUpdating; @@ -263,7 +267,7 @@ private async void EyeModelLoad(object? sender, RoutedEventArgs e) Title = "Select ONNX Model", AllowMultiple = false, SuggestedStartLocation = suggestedStartLocation, // Falls back to desktop if Models folder hasn't been created yet - FileTypeFilter = [OnnxAll] + FileTypeFilter = [OnnxAll, RustTrainedModels], })!; if (file.Count == 0) return; diff --git a/src/VRCFaceTracking.Baballonia/BabbleOSC.cs b/src/VRCFaceTracking.Baballonia/BabbleOSC.cs index 6c071998..10033f76 100644 --- a/src/VRCFaceTracking.Baballonia/BabbleOSC.cs +++ b/src/VRCFaceTracking.Baballonia/BabbleOSC.cs @@ -87,9 +87,9 @@ private void ListenLoop() case "/LeftEyeWiden": EyeExpressions[(int)ExpressionMapping.EyeLeftWiden] = value; break; - // case "/LeftEyeLower": - // EyeExpressions[(int)ExpressionMapping.EyeLeftLower] = value; - // break; + case "/LeftEyeLower": + EyeExpressions[(int)ExpressionMapping.EyeLeftLower] = value; + break; case "/LeftEyeBrow": EyeExpressions[(int)ExpressionMapping.EyeLeftSquint] = value; break; @@ -105,9 +105,9 @@ private void ListenLoop() case "/RightEyeWiden": EyeExpressions[(int)ExpressionMapping.EyeRightWiden] = value; break; - // case "/RightEyeLower": - // EyeExpressions[(int)ExpressionMapping.EyeRightLower] = value; - // break; + case "/RightEyeLower": + EyeExpressions[(int)ExpressionMapping.EyeRightLower] = value; + break; case "/RightEyeBrow": EyeExpressions[(int)ExpressionMapping.EyeRightSquint] = value; break; diff --git a/src/babble_trainer b/src/babble_trainer new file mode 160000 index 00000000..670c0b4b --- /dev/null +++ b/src/babble_trainer @@ -0,0 +1 @@ +Subproject commit 670c0b4b9f63d0fc54daf335d3f6e6e5c86ab07c From ecfc8de797caaa0bfccfb9abf1806242593931e7 Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Wed, 18 Mar 2026 22:17:45 +1100 Subject: [PATCH 02/13] Add some compatibility with older models --- src/Baballonia/Services/EyePipelineManager.cs | 27 ++++++++++- .../Inference/EyeProcessingPipeline.cs | 48 +++++++++++++++++-- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/src/Baballonia/Services/EyePipelineManager.cs b/src/Baballonia/Services/EyePipelineManager.cs index f61de78c..75a03382 100644 --- a/src/Baballonia/Services/EyePipelineManager.cs +++ b/src/Baballonia/Services/EyePipelineManager.cs @@ -54,10 +54,33 @@ public void InitializePipeline() public async Task LoadInferenceAsync() { - await Task.Run(CreateInference); + var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel"); + if (eyeModelName != null && eyeModelName.EndsWith(".bin.gz", StringComparison.OrdinalIgnoreCase)) + { + await Task.Run(LoadRustInference); + return; + } + + var inf = await Task.Run(LoadOnnxInference); + _pipeline.InferenceService = inf; + } + + private DefaultInferenceRunner LoadOnnxInference() + { + const string defaultEyeModelName = "eyeModel.onnx"; + var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel", defaultEyeModelName); + var eyeModelPath = Path.Combine(AppContext.BaseDirectory, eyeModelName); + + if (File.Exists(eyeModelPath)) return _inferenceFactory.Create(eyeModelPath); + _logger.LogError("{} Does not exists, Loading default...", eyeModelPath); + + eyeModelName = defaultEyeModelName; + eyeModelPath = Path.Combine(AppContext.BaseDirectory, eyeModelName); + + return _inferenceFactory.Create(eyeModelPath); } - private void CreateInference() + private void LoadRustInference() { var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel"); if (eyeModelName != null) diff --git a/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs b/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs index 24e2eeb6..dd1adc1a 100644 --- a/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs +++ b/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs @@ -1,19 +1,23 @@ -using Baballonia.Services.events; +using System; +using Baballonia.Services.events; using Baballonia.Services.Inference.Enums; -using System; using babble_model.Net.Sys; +using Microsoft.Extensions.Logging; using OpenCvSharp; namespace Baballonia.Services.Inference; public class EyeProcessingPipeline : DefaultProcessingPipeline, IDisposable { + private readonly ILogger _logger; private readonly IEyePipelineEventBus _eyePipelineEventBus; private readonly FastCorruptionDetector.FastCorruptionDetector _fastCorruptionDetector = new(); private readonly ImageCollector _imageCollector = new(); + private bool UseRustPipeline = false; - public EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) + public EyeProcessingPipeline(ILogger logger, IEyePipelineEventBus eyePipelineEventBus) { + _logger = logger; _eyePipelineEventBus = eyePipelineEventBus; } @@ -34,6 +38,8 @@ public EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) return errorMsg; } + UseRustPipeline = true; + return null; } } @@ -46,6 +52,7 @@ public EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) { string errorMsg = new String((sbyte*)res.value.error_message); Console.WriteLine($"Inference error: {errorMsg}"); + NativeMethods.freeModelOutputResult(res); return null; } @@ -80,11 +87,44 @@ public EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) if (collected == null) return null; - var inferenceResult = RunInference(collected); + float[]? inferenceResult; + + if (UseRustPipeline) + { + inferenceResult = RunInference(collected); + } + else + { + if (InferenceService == null) + return null; + + ImageConverter?.Convert(collected, InferenceService.GetInputTensor()); + + inferenceResult = InferenceService?.Run(); + } if (inferenceResult == null) return null; + if (inferenceResult.Length == 6) + { + // Older model with only look and blink, without eyebrow and eye wide. + // We need to convert it to the new format by adding default values for the missing expressions. + inferenceResult = new float[] + { + inferenceResult[0], // left pitch + inferenceResult[1], // left yaw + inferenceResult[2], // left lid + 0, // left eyebrow (default value) + 0, // left eye wide (default value) + inferenceResult[3], // right pitch + inferenceResult[4], // right yaw + inferenceResult[5], // right lid + 0, // right eyebrow (default value) + 0 // right eye wide (default value) + }; + } + if (Filter != null) { inferenceResult = Filter.Filter(inferenceResult); From c1fc2f7bc58d0b4569dd3ed2c361567a364ad096 Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Wed, 18 Mar 2026 22:18:54 +1100 Subject: [PATCH 03/13] Add the rust training interface --- .../Baballonia.Desktop.csproj | 1 + .../Calibration/OverlayCalibrationService.cs | 9 +- .../Calibration/RustTrainerService.cs | 85 +++++++++++++++++++ .../Calibration/TrainerService.cs | 2 +- src/Baballonia.Desktop/Program.cs | 2 +- .../Trainer/TrainerServiceTest.cs | 2 +- src/Baballonia/Contracts/ITrainerService.cs | 2 +- src/babble_trainer | 2 +- 8 files changed, 96 insertions(+), 9 deletions(-) create mode 100644 src/Baballonia.Desktop/Calibration/RustTrainerService.cs diff --git a/src/Baballonia.Desktop/Baballonia.Desktop.csproj b/src/Baballonia.Desktop/Baballonia.Desktop.csproj index ade07bf0..2adfaf20 100644 --- a/src/Baballonia.Desktop/Baballonia.Desktop.csproj +++ b/src/Baballonia.Desktop/Baballonia.Desktop.csproj @@ -58,6 +58,7 @@ + diff --git a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs index 5f86fcca..a0ec0416 100644 --- a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs +++ b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs @@ -66,13 +66,14 @@ public void Dispose() await calibrationStep.ExecuteAsync(messageDispatcher, _tokenSource.Token); } - var srcPath = Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest.onnx"); + var srcPath = Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest"); var destPath = Path.Combine(Utils.ModelsDirectory, - $"tuned_temporal_eye_tracking_{DateTime.Now:yyyyMMdd_HHmmss}.onnx"); + $"tuned_temporal_eye_tracking_{DateTime.Now:yyyyMMdd_HHmmss}"); - File.Move(srcPath, destPath); + File.Move(srcPath + ".bin.gz", destPath + ".bin.gz"); + File.Move(srcPath + "_config.json", destPath + "_config.json"); - localSettingsService.SaveSetting("EyeHome_EyeModel", destPath); + localSettingsService.SaveSetting("EyeHome_EyeModel", destPath + ".bin.gz"); await eyePipelineManager.LoadInferenceAsync(); if (localSettingsService.ReadSetting("AppSettings_ShareEyeData")) diff --git a/src/Baballonia.Desktop/Calibration/RustTrainerService.cs b/src/Baballonia.Desktop/Calibration/RustTrainerService.cs new file mode 100644 index 00000000..42b913ff --- /dev/null +++ b/src/Baballonia.Desktop/Calibration/RustTrainerService.cs @@ -0,0 +1,85 @@ +using System; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Baballonia.Contracts; +using babble_model.Net.Sys; +using Microsoft.Extensions.Logging; +using OverlaySDK.Packets; + +namespace Baballonia.Desktop.Calibration; + +public partial class RustTrainerService(ILogger logger) : ITrainerService +{ + private readonly object _lock = new(); + + public event Action? OnProgress; + + static event Action? GlobalProgress; + static TaskCompletionSource tcs = null; + + [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvCdecl) })] + static void HandleProgress(TrainingDataCallback data) + { + Console.WriteLine($"Recieved {data.callback_type}: {data.low}/{data.high} ({data.loss})"); + TrainerProgressReportPacket progress; + if (data.callback_type == CallbackType.Batch) + { + progress = new TrainerProgressReportPacket("Batch", data.low, data.high, data.loss); + } else if (data.callback_type == CallbackType.Epoch) + { + progress = new TrainerProgressReportPacket("Epoch", data.low, data.high, data.loss); + } + else + { + tcs.TrySetResult(true); + return; + } + + GlobalProgress?.Invoke(progress); + } + + unsafe void CallTrainer(string usercalbinPath, string outputfilePath) + { + var userCalBytes = System.Text.Encoding.UTF8.GetBytes(usercalbinPath); + fixed (byte* userCalBytesPtr = userCalBytes) + { + var outputFileBytes = System.Text.Encoding.UTF8.GetBytes(outputfilePath); + fixed (byte* outputFileBytesPtr = outputFileBytes) + { + NativeMethods.trainModel(userCalBytesPtr, outputFileBytesPtr, &HandleProgress); + } + } + } + + public async Task RunTraining(string usercalbinPath, string outputfilePath) + { + if (!File.Exists(usercalbinPath)) + throw new FileNotFoundException(usercalbinPath + " not found"); + + + lock (_lock) + { + tcs = new TaskCompletionSource(); + GlobalProgress += OnProgress; + } + await Task.Run(() => CallTrainer(usercalbinPath, outputfilePath)); + } + + public Task WaitAsync() + { + return tcs.Task; + } + + public void Dispose() + { + lock (_lock) + { + } + } +} diff --git a/src/Baballonia.Desktop/Calibration/TrainerService.cs b/src/Baballonia.Desktop/Calibration/TrainerService.cs index 50378bbd..df4aa5e6 100644 --- a/src/Baballonia.Desktop/Calibration/TrainerService.cs +++ b/src/Baballonia.Desktop/Calibration/TrainerService.cs @@ -81,7 +81,7 @@ void NewLineEventHandler(object sender, DataReceivedEventArgs dataReceivedEventA return; } - public void RunTraining(string usercalbinPath, string outputfilePath) + public async Task RunTraining(string usercalbinPath, string outputfilePath) { if (!File.Exists(usercalbinPath)) throw new FileNotFoundException(usercalbinPath + " not found"); diff --git a/src/Baballonia.Desktop/Program.cs b/src/Baballonia.Desktop/Program.cs index c187d657..6cecee0b 100644 --- a/src/Baballonia.Desktop/Program.cs +++ b/src/Baballonia.Desktop/Program.cs @@ -44,7 +44,7 @@ public static int Main(string[] args) App.RegisterPlatformSpecificServices(collection => { collection.AddSingleton(); - collection.AddSingleton(); + collection.AddSingleton(); collection.AddSingleton(); collection.AddSingleton(); }); diff --git a/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs b/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs index a50520e1..8cf2d3f1 100644 --- a/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs +++ b/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs @@ -17,7 +17,7 @@ public async Task Test() var log = factory.CreateLogger(); TrainerService service = new TrainerService(log); - service.RunTraining("test.bin", "model.onnx"); + await service.RunTraining("test.bin", "model.onnx"); await service.WaitAsync(); } diff --git a/src/Baballonia/Contracts/ITrainerService.cs b/src/Baballonia/Contracts/ITrainerService.cs index b84b1a83..33b2180a 100644 --- a/src/Baballonia/Contracts/ITrainerService.cs +++ b/src/Baballonia/Contracts/ITrainerService.cs @@ -7,7 +7,7 @@ namespace Baballonia.Contracts; public interface ITrainerService : IDisposable { public event Action? OnProgress; - public void RunTraining(string usercalbinPath, string outputfilePath); + public Task RunTraining(string usercalbinPath, string outputfilePath); public Task WaitAsync(); } diff --git a/src/babble_trainer b/src/babble_trainer index 670c0b4b..e2d2972c 160000 --- a/src/babble_trainer +++ b/src/babble_trainer @@ -1 +1 @@ -Subproject commit 670c0b4b9f63d0fc54daf335d3f6e6e5c86ab07c +Subproject commit e2d2972c4482d491638515c0d2c4db35a789a70b From 3b516719ebcc4deee49765f10cb18e3a80dd6a52 Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 08:26:48 +1100 Subject: [PATCH 04/13] Correctly the calibration output file and some parameter send fixes --- .../Calibration/ICalibrationRoutine.cs | 4 ++-- src/Baballonia/Services/CalibrationService.cs | 4 ++-- src/Baballonia/Services/ParameterSenderService.cs | 8 ++++---- .../SplitViewPane/CalibrationViewModel.cs | 14 +++++++------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs index 53a83228..debfd0c0 100644 --- a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs +++ b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs @@ -297,8 +297,8 @@ public async Task ExecuteAsync(OverlayMessageDispatcher dispatcher, Cancellation dispatcher.Dispatch(new RunVariableLenghtRoutinePacket(Name, TimeSpan.FromSeconds(120))); var onProgressHandler = (TrainerProgressReportPacket packet) => { dispatcher.Dispatch(packet); }; overlayTrainer.OnProgress += onProgressHandler; - overlayTrainer.RunTraining(Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"), - Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest.onnx")); + await overlayTrainer.RunTraining(Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"), + Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest")); await overlayTrainer.WaitAsync(); overlayTrainer.OnProgress -= onProgressHandler; diff --git a/src/Baballonia/Services/CalibrationService.cs b/src/Baballonia/Services/CalibrationService.cs index 2695be4b..d6969cc8 100644 --- a/src/Baballonia/Services/CalibrationService.cs +++ b/src/Baballonia/Services/CalibrationService.cs @@ -20,13 +20,13 @@ public class CalibrationService : ICalibrationService { { "LeftEyeLid", "/LeftEyeLid" }, { "LeftEyeWiden", "/LeftEyeWiden" }, - // { "LeftEyeLower", "/LeftEyeLower" }, + { "LeftEyeLower", "/LeftEyeLower" }, { "LeftEyeBrow", "/LeftEyeBrow" }, { "RightEyeX", "/RightEyeX" }, { "RightEyeY", "/RightEyeY" }, { "RightEyeLid", "/RightEyeLid" }, { "RightEyeWiden", "/RightEyeWiden" }, - // { "RightEyeLower", "/RightEyeLower" }, + { "RightEyeLower", "/RightEyeLower" }, { "RightEyeBrow", "/RightEyeBrow" }, { "CheekPuffLeft", "/cheekPuffLeft" }, { "CheekPuffRight", "/cheekPuffRight" }, diff --git a/src/Baballonia/Services/ParameterSenderService.cs b/src/Baballonia/Services/ParameterSenderService.cs index 9dcc88c1..f438d0a5 100644 --- a/src/Baballonia/Services/ParameterSenderService.cs +++ b/src/Baballonia/Services/ParameterSenderService.cs @@ -32,15 +32,15 @@ public class ParameterSenderService : BackgroundService { "LeftEyeX", "/LeftEyeX" }, { "LeftEyeY", "/LeftEyeY" }, { "LeftEyeLid", "/LeftEyeLid" }, - { "LeftEyeBrow", "/LeftEyeBrow" }, + { "LeftEyeLower", "/LeftEyeLower" }, { "LeftEyeWiden", "/LeftEyeWiden" }, - //{ "LeftEyeLower", "/LeftEyeLower" }, + //{ "LeftEyeBrow", "/LeftEyeBrow" }, { "RightEyeX", "/RightEyeX" }, { "RightEyeY", "/RightEyeY" }, { "RightEyeLid", "/RightEyeLid" }, - { "RightEyeBrow", "/RightEyeBrow" }, + { "RightEyeLower", "/RightEyeLower" }, { "RightEyeWiden", "/RightEyeWiden" }, - //{ "RightEyeLower", "/RightEyeLower" }, + //{ "RightEyeBrow", "/RightEyeBrow" }, }; public readonly Dictionary FaceExpressionMap = new() diff --git a/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs b/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs index 62d2c370..5e3b505c 100644 --- a/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs +++ b/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs @@ -127,14 +127,14 @@ public CalibrationViewModel(EyePipelineManager eyePipelineManager) { "LeftEyeX", 0 }, { "LeftEyeY", 1 }, { "LeftEyeLid", 2 }, - //{ "LeftEyeWiden", }, - //{ "LeftEyeLower", }, + { "LeftEyeLower", 3 }, + { "LeftEyeWiden", 4 }, //{ "LeftEyeBrow", }, - { "RightEyeX", 3 }, - { "RightEyeY", 4 }, - { "RightEyeLid", 5 }, - //{ "RightEyeWiden", }, - //{ "RightEyeLower", }, + { "RightEyeX", 5 }, + { "RightEyeY", 6 }, + { "RightEyeLid", 7 }, + { "RightEyeLower", 8 }, + { "RightEyeWiden", 9 }, //{ "RightEyeBrow", }, }; From f319bc3f06ace4eb065e6cc7aeeada754028dfa6 Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 08:29:24 +1100 Subject: [PATCH 05/13] Update trainer version --- src/babble_trainer | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/babble_trainer b/src/babble_trainer index e2d2972c..2ece5ec7 160000 --- a/src/babble_trainer +++ b/src/babble_trainer @@ -1 +1 @@ -Subproject commit e2d2972c4482d491638515c0d2c4db35a789a70b +Subproject commit 2ece5ec79e0c62772f3acdac4b4bd1a216a218b0 From e1bb2b13c77bdc7c8358609d5c575b62a784397d Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 08:32:50 +1100 Subject: [PATCH 06/13] Add a calibration option that only runs training on existing data --- .../Calibration/ICalibrationRoutine.cs | 12 ++++++++++++ .../Calibration/OverlayCalibrationService.cs | 1 + src/Baballonia/Helpers/CalibrationRoutine.cs | 1 + src/Baballonia/Views/HomePageView.axaml | 5 +++++ 4 files changed, 19 insertions(+) diff --git a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs index 53a83228..8c6103ef 100644 --- a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs +++ b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs @@ -453,4 +453,16 @@ public IEnumerable BlinkCalibration() return steps; } + + public IEnumerable TrainCalibration() + { + List steps = + [ + new TrainerCalibrationStep(trainer), + new CommandDispatchStep("close") + + ]; + + return steps; + } } diff --git a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs index 5f86fcca..6b1c3129 100644 --- a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs +++ b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs @@ -59,6 +59,7 @@ public void Dispose() CalibrationRoutine.Routines.BasicCalibrationNoTutorial => eyeCalibration.BasicAllCalibrationQuick(), CalibrationRoutine.Routines.GazeOnly => eyeCalibration.GazeCalibration(), CalibrationRoutine.Routines.BlinkOnly => eyeCalibration.BlinkCalibration(), + CalibrationRoutine.Routines.TrainOnly => eyeCalibration.TrainCalibration(), _ => eyeCalibration.BasicAllCalibration() }; foreach (var calibrationStep in steps) diff --git a/src/Baballonia/Helpers/CalibrationRoutine.cs b/src/Baballonia/Helpers/CalibrationRoutine.cs index c0e5999c..a2ea4daa 100644 --- a/src/Baballonia/Helpers/CalibrationRoutine.cs +++ b/src/Baballonia/Helpers/CalibrationRoutine.cs @@ -12,6 +12,7 @@ public enum Routines BasicCalibrationNoTutorial, GazeOnly, BlinkOnly, + TrainOnly, } public static readonly Dictionary Map = Enum.GetValues().ToDictionary(i => i.ToString(), i => i); /* diff --git a/src/Baballonia/Views/HomePageView.axaml b/src/Baballonia/Views/HomePageView.axaml index aa56670e..b5d23f53 100644 --- a/src/Baballonia/Views/HomePageView.axaml +++ b/src/Baballonia/Views/HomePageView.axaml @@ -99,6 +99,11 @@ + + + + + From b06785092e7eeea579b8750a152ac0348254cf5f Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 09:06:38 +1100 Subject: [PATCH 07/13] Add rust trainer to github action --- .github/workflows/build.yml | 16 ++++++++++++++++ src/babble_trainer | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 87dd51dc..64c46852 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,6 +49,14 @@ jobs: - name: Download dependencies run: ./download_dependencies.sh + - name: Setup Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + + - name: Build Rust trainer + run: | + cd src/babble_trainer + cargo build --release + #- name: Install dependencies # run: | # dotnet tool install -g vpk @@ -89,6 +97,14 @@ jobs: shell: pwsh run: ./download_dependencies.ps1 + - name: Setup Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + + - name: Build Rust trainer + run: | + cd src/babble_trainer + cargo build --release + - name: Build project run: | cd src/Baballonia.Desktop diff --git a/src/babble_trainer b/src/babble_trainer index 2ece5ec7..f91b37a4 160000 --- a/src/babble_trainer +++ b/src/babble_trainer @@ -1 +1 @@ -Subproject commit 2ece5ec79e0c62772f3acdac4b4bd1a216a218b0 +Subproject commit f91b37a4a0590a5525d3a48f662be5dbc9673dca From 84d24479a271de0c37b9373fd90457151b086a5a Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 09:16:17 +1100 Subject: [PATCH 08/13] Update babble trainer and set correct rust-src dir --- .github/workflows/build.yml | 4 ++++ src/babble_trainer | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 64c46852..43891b45 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -51,6 +51,8 @@ jobs: - name: Setup Rust toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + rust-src-dir: src/babble_trainer - name: Build Rust trainer run: | @@ -99,6 +101,8 @@ jobs: - name: Setup Rust toolchain uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + rust-src-dir: src/babble_trainer - name: Build Rust trainer run: | diff --git a/src/babble_trainer b/src/babble_trainer index f91b37a4..c6e13d5f 160000 --- a/src/babble_trainer +++ b/src/babble_trainer @@ -1 +1 @@ -Subproject commit f91b37a4a0590a5525d3a48f662be5dbc9673dca +Subproject commit c6e13d5f9f008f60697396593c8fb55ad469b4b0 From eb158afb9a31c84258bf7e40f6bd1f9b16e4e95f Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 09:42:13 +1100 Subject: [PATCH 09/13] Fix issues in Rust Trainer --- .github/workflows/build.yml | 4 ++-- .../Calibration/RustTrainerService.cs | 16 +++++++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 43891b45..1f5585ae 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -57,7 +57,7 @@ jobs: - name: Build Rust trainer run: | cd src/babble_trainer - cargo build --release + cargo build --release --lib #- name: Install dependencies # run: | @@ -107,7 +107,7 @@ jobs: - name: Build Rust trainer run: | cd src/babble_trainer - cargo build --release + cargo build --release --lib - name: Build project run: | diff --git a/src/Baballonia.Desktop/Calibration/RustTrainerService.cs b/src/Baballonia.Desktop/Calibration/RustTrainerService.cs index 3caa4807..8c621135 100644 --- a/src/Baballonia.Desktop/Calibration/RustTrainerService.cs +++ b/src/Baballonia.Desktop/Calibration/RustTrainerService.cs @@ -14,19 +14,19 @@ namespace Baballonia.Desktop.Calibration; -public partial class RustTrainerService(ILogger logger) : ITrainerService +public partial class RustTrainerService : ITrainerService { private readonly object _lock = new(); public event Action? OnProgress; static event Action? GlobalProgress; - static TaskCompletionSource tcs = null; + static TaskCompletionSource? tcs; [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvCdecl) })] static void HandleProgress(TrainingDataCallback data) { - logger.LogInformation($"Recieved {data.callback_type}: {data.low}/{data.high} ({data.loss})"); + Console.WriteLine($"Recieved {data.callback_type}: {data.low}/{data.high} ({data.loss})"); TrainerProgressReportPacket progress; if (data.callback_type == CallbackType.Batch) { @@ -37,7 +37,13 @@ static void HandleProgress(TrainingDataCallback data) } else { - tcs.TrySetResult(true); + if (tcs != null) + { + tcs.TrySetResult(true); + } else + { + Console.WriteLine("tcs is null when trying to set result"); + } return; } @@ -73,7 +79,7 @@ public async Task RunTraining(string usercalbinPath, string outputfilePath) public Task WaitAsync() { - return tcs.Task; + return tcs != null ? tcs.Task : Task.CompletedTask; } public void Dispose() From 4e400ff3685aa424002dc381617b349ebd4c3dab Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 10:33:44 +1100 Subject: [PATCH 10/13] Update babble_trainer --- src/babble_trainer | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/babble_trainer b/src/babble_trainer index c6e13d5f..3372b4bf 160000 --- a/src/babble_trainer +++ b/src/babble_trainer @@ -1 +1 @@ -Subproject commit c6e13d5f9f008f60697396593c8fb55ad469b4b0 +Subproject commit 3372b4bf0890f539e458ab75e49187eb9a8cbab1 From 9a663b6d2c02f466bdb3cc39d1c17f5ddab0c66b Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Sat, 21 Mar 2026 12:00:45 +1100 Subject: [PATCH 11/13] Update trainer to fix linux so inclusion --- src/babble_trainer | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/babble_trainer b/src/babble_trainer index 3372b4bf..50313bd7 160000 --- a/src/babble_trainer +++ b/src/babble_trainer @@ -1 +1 @@ -Subproject commit 3372b4bf0890f539e458ab75e49187eb9a8cbab1 +Subproject commit 50313bd71b40473a99b0facc3d9722d55e05b9fd From 4b5cfc45e8b47aed43f9b989f1739e71a1293138 Mon Sep 17 00:00:00 2001 From: dfgHiatus <51272212+dfgHiatus@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:36:51 -0500 Subject: [PATCH 12/13] Update ICalibrationRoutine.cs --- .../Calibration/ICalibrationRoutine.cs | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs index debfd0c0..34c36b10 100644 --- a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs +++ b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs @@ -354,23 +354,23 @@ public IEnumerable BasicAllCalibration() TimeSpan.FromSeconds(20), lid: 0 ), - // new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(10)), - // eyeCaptureStepFactory.Create("widen", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), widen: 1, lid: 1), - // - // new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(10)), - // eyeCaptureStepFactory.Create("squint", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), squint: 1, lid: 1), - // - // new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(10)), - // eyeCaptureStepFactory.Create("brow", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), browAngry: 1, lid: 1), + new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create("widen", + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), widen: 1, lid: 1), + + new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create("squint", + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), squint: 1, lid: 1), + + new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create("brow", + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), browAngry: 1, lid: 1), // steps.Add(new BaseTutorialStep("covergencetutorial")); // steps.Add(_eyeCaptureStepFactory.Create("covergence", // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_WHATEVER_NOT_IMPLEMENTED)); - // new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), - new MergeBinsStep("gaze.bin", "blink.bin"), + new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), + //new MergeBinsStep("gaze.bin", "blink.bin"), new TrainerCalibrationStep(trainer), new CommandDispatchStep("close") @@ -394,20 +394,20 @@ public IEnumerable BasicAllCalibrationQuick() TimeSpan.FromSeconds(20) ), - //new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(4)), - //eyeCaptureStepFactory.Create("widen", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), - // - //new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(4)), - //eyeCaptureStepFactory.Create("squint", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), - // - //new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(4)), - //eyeCaptureStepFactory.Create("brow", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), - - //new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), - new MergeBinsStep("gaze.bin", "blink.bin"), + new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(4)), + eyeCaptureStepFactory.Create("widen", + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), + + new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(4)), + eyeCaptureStepFactory.Create("squint", + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), + + new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(4)), + eyeCaptureStepFactory.Create("brow", + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), + + new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), + //new MergeBinsStep("gaze.bin", "blink.bin"), new TrainerCalibrationStep(trainer), new CommandDispatchStep("close") From 062003aed8ff5b534fb74e4eb8eaa7a84e6ef5c8 Mon Sep 17 00:00:00 2001 From: acmdf <233685277+acmdf@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:36:29 +1100 Subject: [PATCH 13/13] Add page to choose from different samples --- .../Calibration/AndroidOverlayTrainerCombo.cs | 3 +- .../Calibration/FrameCollector.cs | 4 +- .../Calibration/ICalibrationRoutine.cs | 150 +++++++++--------- .../Calibration/OverlayCalibrationService.cs | 63 ++++---- src/Baballonia/App.axaml.cs | 2 + src/Baballonia/Baballonia.csproj | 3 + src/Baballonia/Contracts/IVROverlay.cs | 3 +- .../Converters/FileNameConverter.cs | 17 ++ src/Baballonia/Helpers/CalibrationRoutine.cs | 5 +- src/Baballonia/Utils/Utils.cs | 4 + src/Baballonia/ViewLocator.cs | 1 + src/Baballonia/ViewModels/MainViewModel.cs | 1 + .../SplitViewPane/EyeTrainingViewModel.cs | 147 +++++++++++++++++ .../SplitViewPane/HomePageViewModel.cs | 29 ++-- src/Baballonia/Views/EyeTrainingView.axaml | 114 +++++++++++++ src/Baballonia/Views/EyeTrainingView.axaml.cs | 34 ++++ src/Baballonia/Views/HomePageView.axaml | 16 -- 17 files changed, 458 insertions(+), 138 deletions(-) create mode 100644 src/Baballonia/Converters/FileNameConverter.cs create mode 100644 src/Baballonia/ViewModels/SplitViewPane/EyeTrainingViewModel.cs create mode 100644 src/Baballonia/Views/EyeTrainingView.axaml create mode 100644 src/Baballonia/Views/EyeTrainingView.axaml.cs diff --git a/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs b/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs index 471701b0..908bce4f 100644 --- a/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs +++ b/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Threading.Tasks; using Baballonia.Contracts; using Baballonia.Helpers; @@ -11,7 +12,7 @@ public Task EyeTrackingCalibrationRequested(string calibrationRoutine) return Task.CompletedTask; } - public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine) + public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine, List args) { return Task.FromResult((true, "Not Supported")); } diff --git a/src/Baballonia.Desktop/Calibration/FrameCollector.cs b/src/Baballonia.Desktop/Calibration/FrameCollector.cs index 3249d071..75a852c9 100644 --- a/src/Baballonia.Desktop/Calibration/FrameCollector.cs +++ b/src/Baballonia.Desktop/Calibration/FrameCollector.cs @@ -80,7 +80,7 @@ public void WriteBin(string path) _frames.Clear(); } - CaptureBin.IO.CaptureBin.WriteAll(Path.Combine(Utils.ModelDataDirectory, path), copy); + CaptureBin.IO.CaptureBin.WriteAll(path, copy); } } public class BinCollector(uint headerFlags) @@ -149,6 +149,6 @@ public void WriteBin(string path) _frames.Clear(); } - CaptureBin.IO.CaptureBin.WriteAll(Path.Combine(Utils.ModelDataDirectory, path), copy); + CaptureBin.IO.CaptureBin.WriteAll(path, copy); } } diff --git a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs index 6db74573..bd0db652 100644 --- a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs +++ b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs @@ -159,7 +159,7 @@ public void Dispose() } } -public class GazeCaptureStep(IEyePipelineEventBus bus, TimeSpan time) : BasePositionalAwareEyeCaptureStep(bus, "gaze", +public class GazeCaptureStep(IEyePipelineEventBus bus, TimeSpan time, string fileName) : BasePositionalAwareEyeCaptureStep(bus, "gaze", fileName, CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_IN_MOVEMENT | CaptureFlags.FLAG_VERSION_BIT1 | @@ -168,7 +168,7 @@ public class GazeCaptureStep(IEyePipelineEventBus bus, TimeSpan time) : BasePosi private Stopwatch _posDataTimer = new(); private readonly TimeSpan _posDataTimeout = TimeSpan.FromSeconds(0.2); - public GazeCaptureStep(IEyePipelineEventBus bus) : this(bus, TimeSpan.FromSeconds(120)) + public GazeCaptureStep(IEyePipelineEventBus bus, string fileName) : this(bus, TimeSpan.FromSeconds(120), fileName) { } @@ -204,6 +204,7 @@ public override void OnNewEyeFrame(EyePipelineEvents.NewTransformedFrameEvent fr public class BasePositionalAwareEyeCaptureStep( IEyePipelineEventBus eyePipelineEvent, string name, + string fileName, uint flags, TimeSpan time) : PositionalAwareCaptureStep(name, flags, time) @@ -224,13 +225,14 @@ public override async Task ExecuteAsync(OverlayMessageDispatcher dispatcher, Can if (ct.IsCancellationRequested) return; - PositionalBinCollector.WriteBin(Name + ".bin"); + PositionalBinCollector.WriteBin(fileName); } } public class BaseEyeCaptureStep( IEyePipelineEventBus eyePipelineEvent, string name, + string fileName, uint flags, TimeSpan time, float lid = 0, @@ -257,7 +259,7 @@ public override async Task ExecuteAsync(OverlayMessageDispatcher dispatcher, Can if (ct.IsCancellationRequested) return; - BinCollector.WriteBin(Name + ".bin"); + BinCollector.WriteBin(fileName); } public override Frame AddFrame(Mat[] images) @@ -307,14 +309,14 @@ await overlayTrainer.RunTraining(Path.Combine(Utils.ModelDataDirectory, "user_ca public class EyeCaptureStepFactory(IEyePipelineEventBus eyePipelineEvent) { - public BaseEyeCaptureStep Create(string name, uint flags, TimeSpan time, + public BaseEyeCaptureStep Create(string name, string fileName, uint flags, TimeSpan time, float lid = 0, float browRaise = 0, float browAngry = 0, float widen = 0, float squint = 0, float dilate = 0) => - new(eyePipelineEvent, name, flags, time, lid, browRaise, browAngry, widen, squint, dilate); + new(eyePipelineEvent, name, fileName, flags, time, lid, browRaise, browAngry, widen, squint, dilate); } public class MergeBinsStep(params string[] binNames) : ICalibrationStep @@ -330,8 +332,7 @@ public Task ExecuteAsync(OverlayMessageDispatcher dispatcher, CancellationToken private static void MergeBins(string result, params string[] inputs) { var resultPath = Path.Combine(Utils.ModelDataDirectory, result); - var inputPaths = inputs.Select(i => Path.Combine(Utils.ModelDataDirectory, i)).ToArray(); - CaptureBin.IO.CaptureBin.Concatenate(resultPath, inputPaths); + CaptureBin.IO.CaptureBin.Concatenate(resultPath, inputs); } } @@ -340,14 +341,69 @@ public class EyeCalibration( ITrainerService trainer, IEyePipelineEventBus eyePipelineEventBus) { + public IEnumerable GetCalibrationStep(string stepName, string fileName) + { + return stepName switch + { + "gaze" => new List + { + new BaseTutorialStep("gazetutorial"), + new GazeCaptureStep(eyePipelineEventBus, fileName), + new CommandDispatchStep("close") + }, + "blink" => new List + { + new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName, fileName, + CaptureFlags.FLAG_GOOD_DATA | + CaptureFlags.FLAG_IN_MOVEMENT | + CaptureFlags.FLAG_VERSION_BIT1, + TimeSpan.FromSeconds(20) + ), + new CommandDispatchStep("close") + }, + "widen" => new List + { + new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName,fileName, + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), widen: 1, lid: 1), + new CommandDispatchStep("close") + }, + "squint" => new List + { + new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName,fileName, + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), squint: 1, lid: 1), + new CommandDispatchStep("close") + }, + "brow" => new List + { + new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName, fileName, + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), browAngry: 1, lid: 1), + new CommandDispatchStep("close") + } + }; + } + + public IEnumerable TrainCalibration(List args) + { + return new List + { + new MergeBinsStep(args.ToArray()), + new TrainerCalibrationStep(trainer), + new CommandDispatchStep("close") + }; + } + public IEnumerable BasicAllCalibration() { List steps = [ new BaseTutorialStep("gazetutorial"), - new GazeCaptureStep(eyePipelineEventBus), + new GazeCaptureStep(eyePipelineEventBus, Path.Combine(Utils.ModelDataDirectory, "gaze.bin")), new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(10)), - eyeCaptureStepFactory.Create("blink", + eyeCaptureStepFactory.Create("blink", Path.Combine(Utils.ModelDataDirectory, "blink.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_IN_MOVEMENT | CaptureFlags.FLAG_VERSION_BIT1, @@ -355,22 +411,22 @@ public IEnumerable BasicAllCalibration() ), new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(10)), - eyeCaptureStepFactory.Create("widen", + eyeCaptureStepFactory.Create("widen", Path.Combine(Utils.ModelDataDirectory, "widen.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), widen: 1, lid: 1), new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(10)), - eyeCaptureStepFactory.Create("squint", + eyeCaptureStepFactory.Create("squint", Path.Combine(Utils.ModelDataDirectory, "squint.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), squint: 1, lid: 1), new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(10)), - eyeCaptureStepFactory.Create("brow", + eyeCaptureStepFactory.Create("brow", Path.Combine(Utils.ModelDataDirectory, "brow.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), browAngry: 1, lid: 1), // steps.Add(new BaseTutorialStep("covergencetutorial")); // steps.Add(_eyeCaptureStepFactory.Create("covergence", // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_WHATEVER_NOT_IMPLEMENTED)); - new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), - //new MergeBinsStep("gaze.bin", "blink.bin"), + new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin"), Path.Combine(Utils.ModelDataDirectory, "widen.bin"), Path.Combine(Utils.ModelDataDirectory, "squint.bin"), Path.Combine(Utils.ModelDataDirectory, "brow.bin")), + //new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin")), new TrainerCalibrationStep(trainer), new CommandDispatchStep("close") @@ -384,9 +440,9 @@ public IEnumerable BasicAllCalibrationQuick() List steps = [ new BaseTutorialStep("gazetutorialshort", TimeSpan.FromSeconds(5)), - new GazeCaptureStep(eyePipelineEventBus, TimeSpan.FromSeconds(10)), + new GazeCaptureStep(eyePipelineEventBus, Path.Combine(Utils.ModelDataDirectory, "gaze.bin")), new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(4)), - eyeCaptureStepFactory.Create("blink", + eyeCaptureStepFactory.Create("blink", Path.Combine(Utils.ModelDataDirectory, "blink.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_IN_MOVEMENT | CaptureFlags.FLAG_VERSION_BIT1 | @@ -395,69 +451,19 @@ public IEnumerable BasicAllCalibrationQuick() ), new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(4)), - eyeCaptureStepFactory.Create("widen", + eyeCaptureStepFactory.Create("widen", Path.Combine(Utils.ModelDataDirectory, "widen.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(4)), - eyeCaptureStepFactory.Create("squint", + eyeCaptureStepFactory.Create("squint", Path.Combine(Utils.ModelDataDirectory, "squint.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(4)), - eyeCaptureStepFactory.Create("brow", + eyeCaptureStepFactory.Create("brow", Path.Combine(Utils.ModelDataDirectory, "brow.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), - new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), - //new MergeBinsStep("gaze.bin", "blink.bin"), - new TrainerCalibrationStep(trainer), - new CommandDispatchStep("close") - - ]; - - return steps; - } - - public IEnumerable GazeCalibration() - { - List steps = - [ - new BaseTutorialStep("gazetutorialshort", TimeSpan.FromSeconds(5)), - new GazeCaptureStep(eyePipelineEventBus), - - new MergeBinsStep("gaze.bin", "blink.bin"), - new TrainerCalibrationStep(trainer), - new CommandDispatchStep("close") - - ]; - - return steps; - } - - public IEnumerable BlinkCalibration() - { - List steps = - [ - new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(4)), - eyeCaptureStepFactory.Create("blink", - CaptureFlags.FLAG_GOOD_DATA | - CaptureFlags.FLAG_IN_MOVEMENT | - CaptureFlags.FLAG_VERSION_BIT1 | - CaptureFlags.FLAG_ROUTINE_BIT1, - TimeSpan.FromSeconds(20) - ), - - new MergeBinsStep("gaze.bin", "blink.bin"), - new TrainerCalibrationStep(trainer), - new CommandDispatchStep("close") - - ]; - - return steps; - } - - public IEnumerable TrainCalibration() - { - List steps = - [ + new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin"), Path.Combine(Utils.ModelDataDirectory, "widen.bin"), Path.Combine(Utils.ModelDataDirectory, "squint.bin"), Path.Combine(Utils.ModelDataDirectory, "brow.bin")), + //new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin")), new TrainerCalibrationStep(trainer), new CommandDispatchStep("close") diff --git a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs index 448eb030..623462e0 100644 --- a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs +++ b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs @@ -5,6 +5,7 @@ using OverlaySDK; using System; using System.IO; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -28,7 +29,7 @@ public void Dispose() } public async Task<(bool success, string status)> EyeTrackingCalibrationRequested( - CalibrationRoutine.Routines routine) + CalibrationRoutine.Routines routine, List args) { if (!overlayProgram.CanStart()) { @@ -57,9 +58,8 @@ public void Dispose() { CalibrationRoutine.Routines.BasicCalibration => eyeCalibration.BasicAllCalibration(), CalibrationRoutine.Routines.BasicCalibrationNoTutorial => eyeCalibration.BasicAllCalibrationQuick(), - CalibrationRoutine.Routines.GazeOnly => eyeCalibration.GazeCalibration(), - CalibrationRoutine.Routines.BlinkOnly => eyeCalibration.BlinkCalibration(), - CalibrationRoutine.Routines.TrainOnly => eyeCalibration.TrainCalibration(), + CalibrationRoutine.Routines.TutorialStep => eyeCalibration.GetCalibrationStep(args[0], args[1]), + CalibrationRoutine.Routines.TrainModel => eyeCalibration.TrainCalibration(args), _ => eyeCalibration.BasicAllCalibration() }; foreach (var calibrationStep in steps) @@ -67,32 +67,37 @@ public void Dispose() await calibrationStep.ExecuteAsync(messageDispatcher, _tokenSource.Token); } - var srcPath = Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest"); - var destPath = Path.Combine(Utils.ModelsDirectory, - $"tuned_temporal_eye_tracking_{DateTime.Now:yyyyMMdd_HHmmss}"); - - if (File.Exists(destPath + ".bin.gz")) - { - File.Move(srcPath + ".bin.gz", destPath + ".bin.gz"); - File.Move(srcPath + "_config.json", destPath + "_config.json"); - - localSettingsService.SaveSetting("EyeHome_EyeModel", destPath + ".bin.gz"); - } else if (File.Exists(destPath + ".onnx")) - { - File.Move(srcPath + ".onnx", destPath + ".onnx"); - - localSettingsService.SaveSetting("EyeHome_EyeModel", destPath + ".onnx"); - } else - { - return (false, "Trained model not found"); - } - - await eyePipelineManager.LoadInferenceAsync(); - - if (localSettingsService.ReadSetting("AppSettings_ShareEyeData")) + if (routine != CalibrationRoutine.Routines.TutorialStep) { - var userCal = Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"); - await dataUploaderService.UploadDataAsync(userCal); + var srcPath = Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest"); + var destPath = Path.Combine(Utils.ModelsDirectory, + $"tuned_temporal_eye_tracking_{DateTime.Now:yyyyMMdd_HHmmss}"); + + if (File.Exists(destPath + ".bin.gz")) + { + File.Move(srcPath + ".bin.gz", destPath + ".bin.gz"); + File.Move(srcPath + "_config.json", destPath + "_config.json"); + + localSettingsService.SaveSetting("EyeHome_EyeModel", destPath + ".bin.gz"); + } + else if (File.Exists(destPath + ".onnx")) + { + File.Move(srcPath + ".onnx", destPath + ".onnx"); + + localSettingsService.SaveSetting("EyeHome_EyeModel", destPath + ".onnx"); + } + else + { + return (false, "Trained model not found"); + } + + await eyePipelineManager.LoadInferenceAsync(); + + if (localSettingsService.ReadSetting("AppSettings_ShareEyeData")) + { + var userCal = Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"); + await dataUploaderService.UploadDataAsync(userCal); + } } await overlayProgram.WaitForExitAsync(); diff --git a/src/Baballonia/App.axaml.cs b/src/Baballonia/App.axaml.cs index 991492a8..458986c6 100644 --- a/src/Baballonia/App.axaml.cs +++ b/src/Baballonia/App.axaml.cs @@ -138,6 +138,8 @@ public override void OnFrameworkInitializationCompleted() services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddTransient(); services.AddTransient(); diff --git a/src/Baballonia/Baballonia.csproj b/src/Baballonia/Baballonia.csproj index 6e527989..f756c5e2 100644 --- a/src/Baballonia/Baballonia.csproj +++ b/src/Baballonia/Baballonia.csproj @@ -107,6 +107,9 @@ True Resources.resx + + EyeTrainingView.axaml + diff --git a/src/Baballonia/Contracts/IVROverlay.cs b/src/Baballonia/Contracts/IVROverlay.cs index 076676b4..33df5324 100644 --- a/src/Baballonia/Contracts/IVROverlay.cs +++ b/src/Baballonia/Contracts/IVROverlay.cs @@ -1,10 +1,11 @@ using Baballonia.Helpers; using System; +using System.Collections.Generic; using System.Threading.Tasks; namespace Baballonia.Contracts; public interface IVROverlay : IDisposable { - public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine); + public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine, List args); } diff --git a/src/Baballonia/Converters/FileNameConverter.cs b/src/Baballonia/Converters/FileNameConverter.cs new file mode 100644 index 00000000..6676a3a0 --- /dev/null +++ b/src/Baballonia/Converters/FileNameConverter.cs @@ -0,0 +1,17 @@ +using Avalonia.Data.Converters; +using System; +using System.Globalization; +using System.IO; + +namespace Baballonia.Converters; + +public class FileNameConverter: IValueConverter +{ + public object? Convert(object? value, Type targetType, object? parameter, CultureInfo culture) + { + return Path.GetFileName((string?) value); + } + + object? IValueConverter.ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture) + => throw new NotSupportedException(); +} diff --git a/src/Baballonia/Helpers/CalibrationRoutine.cs b/src/Baballonia/Helpers/CalibrationRoutine.cs index a2ea4daa..eebfd97e 100644 --- a/src/Baballonia/Helpers/CalibrationRoutine.cs +++ b/src/Baballonia/Helpers/CalibrationRoutine.cs @@ -10,9 +10,8 @@ public enum Routines { BasicCalibration, BasicCalibrationNoTutorial, - GazeOnly, - BlinkOnly, - TrainOnly, + TutorialStep, + TrainModel, } public static readonly Dictionary Map = Enum.GetValues().ToDictionary(i => i.ToString(), i => i); /* diff --git a/src/Baballonia/Utils/Utils.cs b/src/Baballonia/Utils/Utils.cs index 3eca23f9..b26af504 100644 --- a/src/Baballonia/Utils/Utils.cs +++ b/src/Baballonia/Utils/Utils.cs @@ -63,6 +63,10 @@ public static class Utils ? Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "ProjectBabble", "ModelData") : AppContext.BaseDirectory; + public static readonly string ModelTrainingDataDirectory = IsSupportedDesktopOS + ? Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "ProjectBabble", "ModelTrainingData") + : AppContext.BaseDirectory; + public static readonly string VrcftLibsDirectory = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "VRCFaceTracking", "CustomLibs"); diff --git a/src/Baballonia/ViewLocator.cs b/src/Baballonia/ViewLocator.cs index cbfaefdc..327bc9d6 100644 --- a/src/Baballonia/ViewLocator.cs +++ b/src/Baballonia/ViewLocator.cs @@ -20,6 +20,7 @@ public ViewLocator() RegisterViewFactory(); RegisterViewFactory(); RegisterViewFactory(); + RegisterViewFactory(); RegisterViewFactory(); RegisterViewFactory(); diff --git a/src/Baballonia/ViewModels/MainViewModel.cs b/src/Baballonia/ViewModels/MainViewModel.cs index 952157ff..121f6d8b 100644 --- a/src/Baballonia/ViewModels/MainViewModel.cs +++ b/src/Baballonia/ViewModels/MainViewModel.cs @@ -41,6 +41,7 @@ private void SetOverlay(bool show) new(typeof(CalibrationViewModel), "EditRegular", Resources.Calibration_Title_Header), // Calibration new(typeof(FirmwareViewModel), "DeveloperBoardRegular", Resources.Firmware_Title_Header), // Firmware new(typeof(VrcViewModel), "CommentRegular", "VRChat"), // VRChat. No translation :P + new(typeof(EyeTrainingViewModel), "HeadsetVrRegular", "Eye training management"), // Eye Training new(typeof(OutputPageViewModel), "TextFirstLineRegular", Resources.Output_Title_Header), // Output new(typeof(AppSettingsViewModel), "SettingsRegular", Resources.Settings_Title_Header), // Settings new(typeof(AboutPageViewModel), "InfoRegular", Resources.About_Title_Header), // About diff --git a/src/Baballonia/ViewModels/SplitViewPane/EyeTrainingViewModel.cs b/src/Baballonia/ViewModels/SplitViewPane/EyeTrainingViewModel.cs new file mode 100644 index 00000000..2154d964 --- /dev/null +++ b/src/Baballonia/ViewModels/SplitViewPane/EyeTrainingViewModel.cs @@ -0,0 +1,147 @@ +using Avalonia.Controls; +using Avalonia.Media; +using Baballonia.Contracts; +using Baballonia.Helpers; +using Baballonia.Services; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.IO; +using System.Linq; +using System.Threading.Tasks; + +namespace Baballonia.ViewModels.SplitViewPane; + +public partial class EyeTrainingViewModel : ViewModelBase +{ + public partial class CalibrationStep : ObservableObject + { + [ObservableProperty] private string _title; + [ObservableProperty] ObservableCollection _files; + [ObservableProperty] string _selectedFile; + + private readonly string _stepName; + private readonly IVROverlay _vrOverlay; + + public CalibrationStep(IVROverlay vrOverlay, string title, string name) + { + _title = title; + _stepName = name; + _files = new ObservableCollection(); + _vrOverlay = vrOverlay; + + GetExistingFiles(); + + if (_files.Count > 0) + _selectedFile = _files[0]; + } + + public void GetExistingFiles() + { + _files.Clear(); + var calibrationFolder = Path.Combine(Utils.ModelTrainingDataDirectory, _stepName); + if (!Directory.Exists(calibrationFolder)) + { + return; + } + + foreach (var file in Directory.GetFiles(calibrationFolder, "*.bin", SearchOption.AllDirectories)) + { + _files.Add(file); + } + } + + [RelayCommand] + private async Task ReRecord() + { + var res = await Task.Run(async () => + { + try + { + return await _vrOverlay.EyeTrackingCalibrationRequested(CalibrationRoutine.Routines.TutorialStep, new List + { + _stepName, + Path.Combine(Utils.ModelTrainingDataDirectory, _stepName, $"{DateTime.Now:yyyyMMdd_HHmmss}.bin") + }); + } + catch (Exception ex) + { + return (false, ex.Message); + } + } + ); + + GetExistingFiles(); + //if (res.Item1) + //{ + // SelectedCalibrationTextBlock.Foreground = new SolidColorBrush(Colors.Green); + //} + //else + //{ + // SelectedCalibrationTextBlock.Foreground = new SolidColorBrush(Colors.Red); + // _logger.LogError(res.Item2); + //} + } + } + + public ObservableCollection CalibrationSteps { get; set; } + + private readonly IVROverlay _vrOverlay; + private readonly ILogger _logger; + private readonly EyePipelineEventBus _eyePipelineEventBus; + + public Button RetrainButton { get; set; } + + public EyeTrainingViewModel( + IVROverlay vrOverlay, + ILogger logger, + EyePipelineEventBus eyePipelineEventBus) + { + _vrOverlay = vrOverlay; + _logger = logger; + _eyePipelineEventBus = eyePipelineEventBus; + + CalibrationSteps = new ObservableCollection + { + new CalibrationStep(_vrOverlay, "Gaze", "gaze"), + new CalibrationStep(_vrOverlay, "Blink", "blink"), + new CalibrationStep(_vrOverlay, "Eyebrows", "brow"), + new CalibrationStep(_vrOverlay, "Squinting", "squint"), + new CalibrationStep(_vrOverlay, "Widening", "widen"), + }; + } + + + [RelayCommand] + private async Task RetrainModel() + { + + var paths = CalibrationSteps.Select(cs => cs.SelectedFile).ToList(); + + var res = await Task.Run(async () => + { + try + { + return await _vrOverlay.EyeTrackingCalibrationRequested(CalibrationRoutine.Routines.TrainModel, paths); + } + catch (Exception ex) + { + return (false, ex.Message); + } + } + ); + if (res.Item1) + { + RetrainButton.Foreground = new SolidColorBrush(Colors.Green); + } + else + { + RetrainButton.Foreground = new SolidColorBrush(Colors.Red); + _logger.LogError(res.Item2); + } + } + +} diff --git a/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs b/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs index 062df520..52154b71 100644 --- a/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs +++ b/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs @@ -16,6 +16,7 @@ using Microsoft.Extensions.Logging; using OpenCvSharp; using System; +using System.Collections.Generic; using System.Collections.ObjectModel; using System.ComponentModel; using System.Linq; @@ -251,21 +252,21 @@ private void EyeImageUpdateHandler(Mat image) switch (Camera) { case Camera.Left: - { - var leftHalf = new OpenCvSharp.Rect(0, 0, width / 2, height); - var leftRoi = new Mat(image, leftHalf); + { + var leftHalf = new OpenCvSharp.Rect(0, 0, width / 2, height); + var leftRoi = new Mat(image, leftHalf); - UpdateBitmap(leftRoi); - break; - } + UpdateBitmap(leftRoi); + break; + } case Camera.Right: - { - var rightHalf = new OpenCvSharp.Rect(width / 2, 0, width / 2, height); - var rightRoi = new Mat(image, rightHalf); - UpdateBitmap(rightRoi); - break; - } + { + var rightHalf = new OpenCvSharp.Rect(width / 2, 0, width / 2, height); + var rightRoi = new Mat(image, rightHalf); + UpdateBitmap(rightRoi); + break; + } } } else if (channels == 2) @@ -608,7 +609,7 @@ public async Task StartCamera(CameraControllerModel model) await StartCameraWithMaximization(model, startMaximized: true); } - private async Task StartCameraWithMaximization(CameraControllerModel model, bool startMaximized) + private async Task StartCameraWithMaximization(CameraControllerModel model, bool startMaximized) { try { @@ -682,7 +683,7 @@ private async Task RequestVRCalibration() { try { - return await _vrOverlay.EyeTrackingCalibrationRequested(RequestedVRCalibration); + return await _vrOverlay.EyeTrackingCalibrationRequested(RequestedVRCalibration, new List { }); } catch (Exception ex) { diff --git a/src/Baballonia/Views/EyeTrainingView.axaml b/src/Baballonia/Views/EyeTrainingView.axaml new file mode 100644 index 00000000..68554274 --- /dev/null +++ b/src/Baballonia/Views/EyeTrainingView.axaml @@ -0,0 +1,114 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Baballonia/Views/EyeTrainingView.axaml.cs b/src/Baballonia/Views/EyeTrainingView.axaml.cs new file mode 100644 index 00000000..dc929824 --- /dev/null +++ b/src/Baballonia/Views/EyeTrainingView.axaml.cs @@ -0,0 +1,34 @@ +using Avalonia.Controls; +using Avalonia.Interactivity; +using Baballonia.ViewModels.SplitViewPane; + +namespace Baballonia.Views; + +public partial class EyeTrainingView : ViewBase +{ + public EyeTrainingView() + { + InitializeComponent(); + + Loaded += (_, _) => + { + if (DataContext is not EyeTrainingViewModel vm) return; + + vm.RetrainButton = this.Find