diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 87dd51dc..1f5585ae 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,6 +49,16 @@ jobs: - name: Download dependencies run: ./download_dependencies.sh + - name: Setup Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + rust-src-dir: src/babble_trainer + + - name: Build Rust trainer + run: | + cd src/babble_trainer + cargo build --release --lib + #- name: Install dependencies # run: | # dotnet tool install -g vpk @@ -89,6 +99,16 @@ jobs: shell: pwsh run: ./download_dependencies.ps1 + - name: Setup Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + rust-src-dir: src/babble_trainer + + - name: Build Rust trainer + run: | + cd src/babble_trainer + cargo build --release --lib + - name: Build project run: | cd src/Baballonia.Desktop diff --git a/.gitmodules b/.gitmodules index 36dad3ba..2024a7ff 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "src/espflash"] path = src/espflash url = https://github.com/esp-rs/espflash.git +[submodule "src/babble_trainer"] + path = src/babble_trainer + url = https://github.com/acmdf/babble_trainer diff --git a/Baballonia.sln b/Baballonia.sln index e151a5ef..1ea01223 100644 --- a/Baballonia.sln +++ b/Baballonia.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 17 -VisualStudioVersion = 17.13.35919.96 +# Visual Studio Version 18 +VisualStudioVersion = 18.4.11605.240 stable MinimumVisualStudioVersion = 10.0.40219.1 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Baballonia", "src\Baballonia\Baballonia.csproj", "{00505DCC-588E-4E46-8F91-2AF6A88CBC78}" EndProject @@ -47,6 +47,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Baballonia.CaptureBin.IO", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Baballonia.LibV4L2Capture", "src\Baballonia.LibV4L2Capture\Baballonia.LibV4L2Capture.csproj", "{02E9F6A2-A443-491D-93FF-6F002F3C494F}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "babble_trainer", "src\babble_trainer\bindings\bindings.csproj", "{692165EE-1668-066F-9D57-CC06413C6A55}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -126,6 +128,10 @@ Global {02E9F6A2-A443-491D-93FF-6F002F3C494F}.Debug|Any CPU.Build.0 = Debug|Any CPU {02E9F6A2-A443-491D-93FF-6F002F3C494F}.Release|Any CPU.ActiveCfg = Release|Any CPU {02E9F6A2-A443-491D-93FF-6F002F3C494F}.Release|Any CPU.Build.0 = Release|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Debug|Any CPU.Build.0 = Debug|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Release|Any CPU.ActiveCfg = Release|Any CPU + {692165EE-1668-066F-9D57-CC06413C6A55}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs b/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs index 471701b0..908bce4f 100644 --- a/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs +++ b/src/Baballonia.Android/Calibration/AndroidOverlayTrainerCombo.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Threading.Tasks; using Baballonia.Contracts; using Baballonia.Helpers; @@ -11,7 +12,7 @@ public Task EyeTrackingCalibrationRequested(string calibrationRoutine) return Task.CompletedTask; } - public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine) + public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine, List args) { return Task.FromResult((true, "Not Supported")); } diff --git a/src/Baballonia.Desktop/Baballonia.Desktop.csproj b/src/Baballonia.Desktop/Baballonia.Desktop.csproj index 3f351ec2..75a523e4 100644 --- a/src/Baballonia.Desktop/Baballonia.Desktop.csproj +++ b/src/Baballonia.Desktop/Baballonia.Desktop.csproj @@ -61,6 +61,7 @@ + diff --git a/src/Baballonia.Desktop/Calibration/FrameCollector.cs b/src/Baballonia.Desktop/Calibration/FrameCollector.cs index 3249d071..75a852c9 100644 --- a/src/Baballonia.Desktop/Calibration/FrameCollector.cs +++ b/src/Baballonia.Desktop/Calibration/FrameCollector.cs @@ -80,7 +80,7 @@ public void WriteBin(string path) _frames.Clear(); } - CaptureBin.IO.CaptureBin.WriteAll(Path.Combine(Utils.ModelDataDirectory, path), copy); + CaptureBin.IO.CaptureBin.WriteAll(path, copy); } } public class BinCollector(uint headerFlags) @@ -149,6 +149,6 @@ public void WriteBin(string path) _frames.Clear(); } - CaptureBin.IO.CaptureBin.WriteAll(Path.Combine(Utils.ModelDataDirectory, path), copy); + CaptureBin.IO.CaptureBin.WriteAll(path, copy); } } diff --git a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs index 53a83228..bd0db652 100644 --- a/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs +++ b/src/Baballonia.Desktop/Calibration/ICalibrationRoutine.cs @@ -159,7 +159,7 @@ public void Dispose() } } -public class GazeCaptureStep(IEyePipelineEventBus bus, TimeSpan time) : BasePositionalAwareEyeCaptureStep(bus, "gaze", +public class GazeCaptureStep(IEyePipelineEventBus bus, TimeSpan time, string fileName) : BasePositionalAwareEyeCaptureStep(bus, "gaze", fileName, CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_IN_MOVEMENT | CaptureFlags.FLAG_VERSION_BIT1 | @@ -168,7 +168,7 @@ public class GazeCaptureStep(IEyePipelineEventBus bus, TimeSpan time) : BasePosi private Stopwatch _posDataTimer = new(); private readonly TimeSpan _posDataTimeout = TimeSpan.FromSeconds(0.2); - public GazeCaptureStep(IEyePipelineEventBus bus) : this(bus, TimeSpan.FromSeconds(120)) + public GazeCaptureStep(IEyePipelineEventBus bus, string fileName) : this(bus, TimeSpan.FromSeconds(120), fileName) { } @@ -204,6 +204,7 @@ public override void OnNewEyeFrame(EyePipelineEvents.NewTransformedFrameEvent fr public class BasePositionalAwareEyeCaptureStep( IEyePipelineEventBus eyePipelineEvent, string name, + string fileName, uint flags, TimeSpan time) : PositionalAwareCaptureStep(name, flags, time) @@ -224,13 +225,14 @@ public override async Task ExecuteAsync(OverlayMessageDispatcher dispatcher, Can if (ct.IsCancellationRequested) return; - PositionalBinCollector.WriteBin(Name + ".bin"); + PositionalBinCollector.WriteBin(fileName); } } public class BaseEyeCaptureStep( IEyePipelineEventBus eyePipelineEvent, string name, + string fileName, uint flags, TimeSpan time, float lid = 0, @@ -257,7 +259,7 @@ public override async Task ExecuteAsync(OverlayMessageDispatcher dispatcher, Can if (ct.IsCancellationRequested) return; - BinCollector.WriteBin(Name + ".bin"); + BinCollector.WriteBin(fileName); } public override Frame AddFrame(Mat[] images) @@ -297,8 +299,8 @@ public async Task ExecuteAsync(OverlayMessageDispatcher dispatcher, Cancellation dispatcher.Dispatch(new RunVariableLenghtRoutinePacket(Name, TimeSpan.FromSeconds(120))); var onProgressHandler = (TrainerProgressReportPacket packet) => { dispatcher.Dispatch(packet); }; overlayTrainer.OnProgress += onProgressHandler; - overlayTrainer.RunTraining(Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"), - Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest.onnx")); + await overlayTrainer.RunTraining(Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"), + Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest")); await overlayTrainer.WaitAsync(); overlayTrainer.OnProgress -= onProgressHandler; @@ -307,14 +309,14 @@ public async Task ExecuteAsync(OverlayMessageDispatcher dispatcher, Cancellation public class EyeCaptureStepFactory(IEyePipelineEventBus eyePipelineEvent) { - public BaseEyeCaptureStep Create(string name, uint flags, TimeSpan time, + public BaseEyeCaptureStep Create(string name, string fileName, uint flags, TimeSpan time, float lid = 0, float browRaise = 0, float browAngry = 0, float widen = 0, float squint = 0, float dilate = 0) => - new(eyePipelineEvent, name, flags, time, lid, browRaise, browAngry, widen, squint, dilate); + new(eyePipelineEvent, name, fileName, flags, time, lid, browRaise, browAngry, widen, squint, dilate); } public class MergeBinsStep(params string[] binNames) : ICalibrationStep @@ -330,8 +332,7 @@ public Task ExecuteAsync(OverlayMessageDispatcher dispatcher, CancellationToken private static void MergeBins(string result, params string[] inputs) { var resultPath = Path.Combine(Utils.ModelDataDirectory, result); - var inputPaths = inputs.Select(i => Path.Combine(Utils.ModelDataDirectory, i)).ToArray(); - CaptureBin.IO.CaptureBin.Concatenate(resultPath, inputPaths); + CaptureBin.IO.CaptureBin.Concatenate(resultPath, inputs); } } @@ -340,37 +341,92 @@ public class EyeCalibration( ITrainerService trainer, IEyePipelineEventBus eyePipelineEventBus) { + public IEnumerable GetCalibrationStep(string stepName, string fileName) + { + return stepName switch + { + "gaze" => new List + { + new BaseTutorialStep("gazetutorial"), + new GazeCaptureStep(eyePipelineEventBus, fileName), + new CommandDispatchStep("close") + }, + "blink" => new List + { + new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName, fileName, + CaptureFlags.FLAG_GOOD_DATA | + CaptureFlags.FLAG_IN_MOVEMENT | + CaptureFlags.FLAG_VERSION_BIT1, + TimeSpan.FromSeconds(20) + ), + new CommandDispatchStep("close") + }, + "widen" => new List + { + new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName,fileName, + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), widen: 1, lid: 1), + new CommandDispatchStep("close") + }, + "squint" => new List + { + new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName,fileName, + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), squint: 1, lid: 1), + new CommandDispatchStep("close") + }, + "brow" => new List + { + new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create(stepName, fileName, + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), browAngry: 1, lid: 1), + new CommandDispatchStep("close") + } + }; + } + + public IEnumerable TrainCalibration(List args) + { + return new List + { + new MergeBinsStep(args.ToArray()), + new TrainerCalibrationStep(trainer), + new CommandDispatchStep("close") + }; + } + public IEnumerable BasicAllCalibration() { List steps = [ new BaseTutorialStep("gazetutorial"), - new GazeCaptureStep(eyePipelineEventBus), + new GazeCaptureStep(eyePipelineEventBus, Path.Combine(Utils.ModelDataDirectory, "gaze.bin")), new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(10)), - eyeCaptureStepFactory.Create("blink", + eyeCaptureStepFactory.Create("blink", Path.Combine(Utils.ModelDataDirectory, "blink.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_IN_MOVEMENT | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), lid: 0 ), - // new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(10)), - // eyeCaptureStepFactory.Create("widen", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), widen: 1, lid: 1), - // - // new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(10)), - // eyeCaptureStepFactory.Create("squint", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), squint: 1, lid: 1), - // - // new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(10)), - // eyeCaptureStepFactory.Create("brow", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), browAngry: 1, lid: 1), + new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create("widen", Path.Combine(Utils.ModelDataDirectory, "widen.bin"), + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), widen: 1, lid: 1), + + new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create("squint", Path.Combine(Utils.ModelDataDirectory, "squint.bin"), + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), squint: 1, lid: 1), + + new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(10)), + eyeCaptureStepFactory.Create("brow", Path.Combine(Utils.ModelDataDirectory, "brow.bin"), + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20), browAngry: 1, lid: 1), // steps.Add(new BaseTutorialStep("covergencetutorial")); // steps.Add(_eyeCaptureStepFactory.Create("covergence", // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_WHATEVER_NOT_IMPLEMENTED)); - // new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), - new MergeBinsStep("gaze.bin", "blink.bin"), + new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin"), Path.Combine(Utils.ModelDataDirectory, "widen.bin"), Path.Combine(Utils.ModelDataDirectory, "squint.bin"), Path.Combine(Utils.ModelDataDirectory, "brow.bin")), + //new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin")), new TrainerCalibrationStep(trainer), new CommandDispatchStep("close") @@ -384,60 +440,9 @@ public IEnumerable BasicAllCalibrationQuick() List steps = [ new BaseTutorialStep("gazetutorialshort", TimeSpan.FromSeconds(5)), - new GazeCaptureStep(eyePipelineEventBus, TimeSpan.FromSeconds(10)), - new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(4)), - eyeCaptureStepFactory.Create("blink", - CaptureFlags.FLAG_GOOD_DATA | - CaptureFlags.FLAG_IN_MOVEMENT | - CaptureFlags.FLAG_VERSION_BIT1 | - CaptureFlags.FLAG_ROUTINE_BIT1, - TimeSpan.FromSeconds(20) - ), - - //new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(4)), - //eyeCaptureStepFactory.Create("widen", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), - // - //new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(4)), - //eyeCaptureStepFactory.Create("squint", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), - // - //new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(4)), - //eyeCaptureStepFactory.Create("brow", - // CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), - - //new MergeBinsStep("gaze.bin", "blink.bin", "widen.bin", "squint.bin", "brow.bin"), - new MergeBinsStep("gaze.bin", "blink.bin"), - new TrainerCalibrationStep(trainer), - new CommandDispatchStep("close") - - ]; - - return steps; - } - - public IEnumerable GazeCalibration() - { - List steps = - [ - new BaseTutorialStep("gazetutorialshort", TimeSpan.FromSeconds(5)), - new GazeCaptureStep(eyePipelineEventBus), - - new MergeBinsStep("gaze.bin", "blink.bin"), - new TrainerCalibrationStep(trainer), - new CommandDispatchStep("close") - - ]; - - return steps; - } - - public IEnumerable BlinkCalibration() - { - List steps = - [ + new GazeCaptureStep(eyePipelineEventBus, Path.Combine(Utils.ModelDataDirectory, "gaze.bin")), new BaseTutorialStep("blinktutorial", TimeSpan.FromSeconds(4)), - eyeCaptureStepFactory.Create("blink", + eyeCaptureStepFactory.Create("blink", Path.Combine(Utils.ModelDataDirectory, "blink.bin"), CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_IN_MOVEMENT | CaptureFlags.FLAG_VERSION_BIT1 | @@ -445,7 +450,20 @@ public IEnumerable BlinkCalibration() TimeSpan.FromSeconds(20) ), - new MergeBinsStep("gaze.bin", "blink.bin"), + new BaseTutorialStep("widentutorial", TimeSpan.FromSeconds(4)), + eyeCaptureStepFactory.Create("widen", Path.Combine(Utils.ModelDataDirectory, "widen.bin"), + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), + + new BaseTutorialStep("squinttutorial", TimeSpan.FromSeconds(4)), + eyeCaptureStepFactory.Create("squint", Path.Combine(Utils.ModelDataDirectory, "squint.bin"), + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), + + new BaseTutorialStep("browtutorial", TimeSpan.FromSeconds(4)), + eyeCaptureStepFactory.Create("brow", Path.Combine(Utils.ModelDataDirectory, "brow.bin"), + CaptureFlags.FLAG_GOOD_DATA | CaptureFlags.FLAG_VERSION_BIT1, TimeSpan.FromSeconds(20)), + + new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin"), Path.Combine(Utils.ModelDataDirectory, "widen.bin"), Path.Combine(Utils.ModelDataDirectory, "squint.bin"), Path.Combine(Utils.ModelDataDirectory, "brow.bin")), + //new MergeBinsStep(Path.Combine(Utils.ModelDataDirectory, "gaze.bin"), Path.Combine(Utils.ModelDataDirectory, "blink.bin")), new TrainerCalibrationStep(trainer), new CommandDispatchStep("close") diff --git a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs index 5f86fcca..623462e0 100644 --- a/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs +++ b/src/Baballonia.Desktop/Calibration/OverlayCalibrationService.cs @@ -5,6 +5,7 @@ using OverlaySDK; using System; using System.IO; +using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -28,7 +29,7 @@ public void Dispose() } public async Task<(bool success, string status)> EyeTrackingCalibrationRequested( - CalibrationRoutine.Routines routine) + CalibrationRoutine.Routines routine, List args) { if (!overlayProgram.CanStart()) { @@ -57,8 +58,8 @@ public void Dispose() { CalibrationRoutine.Routines.BasicCalibration => eyeCalibration.BasicAllCalibration(), CalibrationRoutine.Routines.BasicCalibrationNoTutorial => eyeCalibration.BasicAllCalibrationQuick(), - CalibrationRoutine.Routines.GazeOnly => eyeCalibration.GazeCalibration(), - CalibrationRoutine.Routines.BlinkOnly => eyeCalibration.BlinkCalibration(), + CalibrationRoutine.Routines.TutorialStep => eyeCalibration.GetCalibrationStep(args[0], args[1]), + CalibrationRoutine.Routines.TrainModel => eyeCalibration.TrainCalibration(args), _ => eyeCalibration.BasicAllCalibration() }; foreach (var calibrationStep in steps) @@ -66,19 +67,37 @@ public void Dispose() await calibrationStep.ExecuteAsync(messageDispatcher, _tokenSource.Token); } - var srcPath = Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest.onnx"); - var destPath = Path.Combine(Utils.ModelsDirectory, - $"tuned_temporal_eye_tracking_{DateTime.Now:yyyyMMdd_HHmmss}.onnx"); - - File.Move(srcPath, destPath); - - localSettingsService.SaveSetting("EyeHome_EyeModel", destPath); - await eyePipelineManager.LoadInferenceAsync(); - - if (localSettingsService.ReadSetting("AppSettings_ShareEyeData")) + if (routine != CalibrationRoutine.Routines.TutorialStep) { - var userCal = Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"); - await dataUploaderService.UploadDataAsync(userCal); + var srcPath = Path.Combine(Utils.ModelDataDirectory, "tuned_temporal_eye_tracking_latest"); + var destPath = Path.Combine(Utils.ModelsDirectory, + $"tuned_temporal_eye_tracking_{DateTime.Now:yyyyMMdd_HHmmss}"); + + if (File.Exists(destPath + ".bin.gz")) + { + File.Move(srcPath + ".bin.gz", destPath + ".bin.gz"); + File.Move(srcPath + "_config.json", destPath + "_config.json"); + + localSettingsService.SaveSetting("EyeHome_EyeModel", destPath + ".bin.gz"); + } + else if (File.Exists(destPath + ".onnx")) + { + File.Move(srcPath + ".onnx", destPath + ".onnx"); + + localSettingsService.SaveSetting("EyeHome_EyeModel", destPath + ".onnx"); + } + else + { + return (false, "Trained model not found"); + } + + await eyePipelineManager.LoadInferenceAsync(); + + if (localSettingsService.ReadSetting("AppSettings_ShareEyeData")) + { + var userCal = Path.Combine(Utils.ModelDataDirectory, "user_cal.bin"); + await dataUploaderService.UploadDataAsync(userCal); + } } await overlayProgram.WaitForExitAsync(); diff --git a/src/Baballonia.Desktop/Calibration/RustTrainerService.cs b/src/Baballonia.Desktop/Calibration/RustTrainerService.cs new file mode 100644 index 00000000..8c621135 --- /dev/null +++ b/src/Baballonia.Desktop/Calibration/RustTrainerService.cs @@ -0,0 +1,91 @@ +using System; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Baballonia.Contracts; +using babble_model.Net.Sys; +using Microsoft.Extensions.Logging; +using OverlaySDK.Packets; + +namespace Baballonia.Desktop.Calibration; + +public partial class RustTrainerService : ITrainerService +{ + private readonly object _lock = new(); + + public event Action? OnProgress; + + static event Action? GlobalProgress; + static TaskCompletionSource? tcs; + + [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvCdecl) })] + static void HandleProgress(TrainingDataCallback data) + { + Console.WriteLine($"Recieved {data.callback_type}: {data.low}/{data.high} ({data.loss})"); + TrainerProgressReportPacket progress; + if (data.callback_type == CallbackType.Batch) + { + progress = new TrainerProgressReportPacket("Batch", data.low, data.high, data.loss); + } else if (data.callback_type == CallbackType.Epoch) + { + progress = new TrainerProgressReportPacket("Epoch", data.low, data.high, data.loss); + } + else + { + if (tcs != null) + { + tcs.TrySetResult(true); + } else + { + Console.WriteLine("tcs is null when trying to set result"); + } + return; + } + + GlobalProgress?.Invoke(progress); + } + + unsafe void CallTrainer(string usercalbinPath, string outputfilePath) + { + var userCalBytes = System.Text.Encoding.UTF8.GetBytes(usercalbinPath); + fixed (byte* userCalBytesPtr = userCalBytes) + { + var outputFileBytes = System.Text.Encoding.UTF8.GetBytes(outputfilePath); + fixed (byte* outputFileBytesPtr = outputFileBytes) + { + NativeMethods.trainModel(userCalBytesPtr, outputFileBytesPtr, &HandleProgress); + } + } + } + + public async Task RunTraining(string usercalbinPath, string outputfilePath) + { + if (!File.Exists(usercalbinPath)) + throw new FileNotFoundException(usercalbinPath + " not found"); + + + lock (_lock) + { + tcs = new TaskCompletionSource(); + GlobalProgress += OnProgress; + } + await Task.Run(() => CallTrainer(usercalbinPath, outputfilePath)); + } + + public Task WaitAsync() + { + return tcs != null ? tcs.Task : Task.CompletedTask; + } + + public void Dispose() + { + lock (_lock) + { + } + } +} diff --git a/src/Baballonia.Desktop/Calibration/TrainerService.cs b/src/Baballonia.Desktop/Calibration/TrainerService.cs index 50378bbd..df4aa5e6 100644 --- a/src/Baballonia.Desktop/Calibration/TrainerService.cs +++ b/src/Baballonia.Desktop/Calibration/TrainerService.cs @@ -81,7 +81,7 @@ void NewLineEventHandler(object sender, DataReceivedEventArgs dataReceivedEventA return; } - public void RunTraining(string usercalbinPath, string outputfilePath) + public async Task RunTraining(string usercalbinPath, string outputfilePath) { if (!File.Exists(usercalbinPath)) throw new FileNotFoundException(usercalbinPath + " not found"); diff --git a/src/Baballonia.Desktop/Program.cs b/src/Baballonia.Desktop/Program.cs index c187d657..6cecee0b 100644 --- a/src/Baballonia.Desktop/Program.cs +++ b/src/Baballonia.Desktop/Program.cs @@ -44,7 +44,7 @@ public static int Main(string[] args) App.RegisterPlatformSpecificServices(collection => { collection.AddSingleton(); - collection.AddSingleton(); + collection.AddSingleton(); collection.AddSingleton(); collection.AddSingleton(); }); diff --git a/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs b/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs index a50520e1..8cf2d3f1 100644 --- a/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs +++ b/src/Baballonia.Tests/Trainer/TrainerServiceTest.cs @@ -17,7 +17,7 @@ public async Task Test() var log = factory.CreateLogger(); TrainerService service = new TrainerService(log); - service.RunTraining("test.bin", "model.onnx"); + await service.RunTraining("test.bin", "model.onnx"); await service.WaitAsync(); } diff --git a/src/Baballonia/App.axaml.cs b/src/Baballonia/App.axaml.cs index 991492a8..458986c6 100644 --- a/src/Baballonia/App.axaml.cs +++ b/src/Baballonia/App.axaml.cs @@ -138,6 +138,8 @@ public override void OnFrameworkInitializationCompleted() services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddTransient(); services.AddTransient(); diff --git a/src/Baballonia/Baballonia.csproj b/src/Baballonia/Baballonia.csproj index 453c4bfe..f756c5e2 100644 --- a/src/Baballonia/Baballonia.csproj +++ b/src/Baballonia/Baballonia.csproj @@ -107,6 +107,9 @@ True Resources.resx + + EyeTrainingView.axaml + @@ -153,6 +156,7 @@ + diff --git a/src/Baballonia/Contracts/ITrainerService.cs b/src/Baballonia/Contracts/ITrainerService.cs index b84b1a83..33b2180a 100644 --- a/src/Baballonia/Contracts/ITrainerService.cs +++ b/src/Baballonia/Contracts/ITrainerService.cs @@ -7,7 +7,7 @@ namespace Baballonia.Contracts; public interface ITrainerService : IDisposable { public event Action? OnProgress; - public void RunTraining(string usercalbinPath, string outputfilePath); + public Task RunTraining(string usercalbinPath, string outputfilePath); public Task WaitAsync(); } diff --git a/src/Baballonia/Contracts/IVROverlay.cs b/src/Baballonia/Contracts/IVROverlay.cs index 076676b4..33df5324 100644 --- a/src/Baballonia/Contracts/IVROverlay.cs +++ b/src/Baballonia/Contracts/IVROverlay.cs @@ -1,10 +1,11 @@ using Baballonia.Helpers; using System; +using System.Collections.Generic; using System.Threading.Tasks; namespace Baballonia.Contracts; public interface IVROverlay : IDisposable { - public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine); + public Task<(bool success, string status)> EyeTrackingCalibrationRequested(CalibrationRoutine.Routines calibrationRoutine, List args); } diff --git a/src/Baballonia/Converters/FileNameConverter.cs b/src/Baballonia/Converters/FileNameConverter.cs new file mode 100644 index 00000000..6676a3a0 --- /dev/null +++ b/src/Baballonia/Converters/FileNameConverter.cs @@ -0,0 +1,17 @@ +using Avalonia.Data.Converters; +using System; +using System.Globalization; +using System.IO; + +namespace Baballonia.Converters; + +public class FileNameConverter: IValueConverter +{ + public object? Convert(object? value, Type targetType, object? parameter, CultureInfo culture) + { + return Path.GetFileName((string?) value); + } + + object? IValueConverter.ConvertBack(object? value, Type targetType, object? parameter, CultureInfo culture) + => throw new NotSupportedException(); +} diff --git a/src/Baballonia/Helpers/CalibrationRoutine.cs b/src/Baballonia/Helpers/CalibrationRoutine.cs index c0e5999c..eebfd97e 100644 --- a/src/Baballonia/Helpers/CalibrationRoutine.cs +++ b/src/Baballonia/Helpers/CalibrationRoutine.cs @@ -10,8 +10,8 @@ public enum Routines { BasicCalibration, BasicCalibrationNoTutorial, - GazeOnly, - BlinkOnly, + TutorialStep, + TrainModel, } public static readonly Dictionary Map = Enum.GetValues().ToDictionary(i => i.ToString(), i => i); /* diff --git a/src/Baballonia/Services/CalibrationService.cs b/src/Baballonia/Services/CalibrationService.cs index 2695be4b..d6969cc8 100644 --- a/src/Baballonia/Services/CalibrationService.cs +++ b/src/Baballonia/Services/CalibrationService.cs @@ -20,13 +20,13 @@ public class CalibrationService : ICalibrationService { { "LeftEyeLid", "/LeftEyeLid" }, { "LeftEyeWiden", "/LeftEyeWiden" }, - // { "LeftEyeLower", "/LeftEyeLower" }, + { "LeftEyeLower", "/LeftEyeLower" }, { "LeftEyeBrow", "/LeftEyeBrow" }, { "RightEyeX", "/RightEyeX" }, { "RightEyeY", "/RightEyeY" }, { "RightEyeLid", "/RightEyeLid" }, { "RightEyeWiden", "/RightEyeWiden" }, - // { "RightEyeLower", "/RightEyeLower" }, + { "RightEyeLower", "/RightEyeLower" }, { "RightEyeBrow", "/RightEyeBrow" }, { "CheekPuffLeft", "/cheekPuffLeft" }, { "CheekPuffRight", "/cheekPuffRight" }, diff --git a/src/Baballonia/Services/EyePipelineManager.cs b/src/Baballonia/Services/EyePipelineManager.cs index 8d3c9e37..75a03382 100644 --- a/src/Baballonia/Services/EyePipelineManager.cs +++ b/src/Baballonia/Services/EyePipelineManager.cs @@ -54,11 +54,18 @@ public void InitializePipeline() public async Task LoadInferenceAsync() { - var inf = await Task.Run(CreateInference); + var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel"); + if (eyeModelName != null && eyeModelName.EndsWith(".bin.gz", StringComparison.OrdinalIgnoreCase)) + { + await Task.Run(LoadRustInference); + return; + } + + var inf = await Task.Run(LoadOnnxInference); _pipeline.InferenceService = inf; } - private DefaultInferenceRunner CreateInference() + private DefaultInferenceRunner LoadOnnxInference() { const string defaultEyeModelName = "eyeModel.onnx"; var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel", defaultEyeModelName); @@ -73,10 +80,23 @@ private DefaultInferenceRunner CreateInference() return _inferenceFactory.Create(eyeModelPath); } - - public void LoadInference() + private void LoadRustInference() { - _pipeline.InferenceService = CreateInference(); + var eyeModelName = _localSettings.ReadSetting("EyeHome_EyeModel"); + if (eyeModelName != null) + { + var eyeModelPath = Path.Combine(AppContext.BaseDirectory, eyeModelName); + + var load_error = _pipeline.LoadInference(eyeModelPath); + + if (load_error != null) + { + _logger.LogError($"Inference error: {load_error}"); + } else + { + _logger.LogInformation($"Inference loaded from {eyeModelPath}"); + } + } } public void LoadFilter() diff --git a/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs b/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs index 87706f44..db435bb2 100644 --- a/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs +++ b/src/Baballonia/Services/Inference/EyeProcessingPipeline.cs @@ -1,16 +1,62 @@ -using Baballonia.Services.events; +using System; +using Baballonia.Services.events; using Baballonia.Services.Inference.Enums; -using System; +using babble_model.Net.Sys; +using Microsoft.Extensions.Logging; +using OpenCvSharp; namespace Baballonia.Services.Inference; -public class EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) : DefaultProcessingPipeline, IDisposable +public class EyeProcessingPipeline(ILogger logger, IEyePipelineEventBus eyePipelineEventBus) : DefaultProcessingPipeline, IDisposable { private readonly FastCorruptionDetector.FastCorruptionDetector _fastCorruptionDetector = new(); private readonly ImageCollector _imageCollector = new(); + private bool UseRustPipeline = false; public bool StabilizeEyes { get; set; } = true; + public unsafe string? LoadInference(string modelPath) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(modelPath); + fixed (byte* ptr = bytes) + { + var output = NativeMethods.loadModel(ptr); + + if (output.is_error) { + var errorMsg = new string((sbyte*)output.value.error_message); + + NativeMethods.freeModelOutputResult(output); + + return errorMsg; + } + + UseRustPipeline = true; + + return null; + } + } + + unsafe float[]? RunInference(Mat collected) + { + var res = NativeMethods.infer(collected.ExtractChannel(0).DataPointer, collected.ExtractChannel(1).DataPointer); + + if (res.is_error) + { + string errorMsg = new String((sbyte*)res.value.error_message); + logger.LogError($"Inference error: {errorMsg}"); + NativeMethods.freeModelOutputResult(res); + return null; + } + + var output = res.value.model_output; + + float[] inferenceResult = { output.pitch_l, output.yaw_l, output.blink_l, output.eyebrow_l, output.eyewide_l, output.pitch_r, output.yaw_r, output.blink_r, output.eyebrow_r, output.eyewide_r }; + + NativeMethods.freeModelOutputResult(res); + + return inferenceResult; + } + public float[]? RunUpdate() { var frame = VideoSource?.GetFrame(ColorType.Gray8); @@ -33,15 +79,44 @@ public class EyeProcessingPipeline(IEyePipelineEventBus eyePipelineEventBus) : D if (collected == null) return null; - if (InferenceService == null) - return null; + float[]? inferenceResult; + + if (UseRustPipeline) + { + inferenceResult = RunInference(collected); + } + else + { + if (InferenceService == null) + return null; + + ImageConverter?.Convert(collected, InferenceService.GetInputTensor()); - ImageConverter?.Convert(collected, InferenceService.GetInputTensor()); + inferenceResult = InferenceService?.Run(); + } - var inferenceResult = InferenceService?.Run(); - if(inferenceResult == null) + if (inferenceResult == null) return null; + if (inferenceResult.Length == 6) + { + // Older model with only look and blink, without eyebrow and eye wide. + // We need to convert it to the new format by adding default values for the missing expressions. + inferenceResult = new float[] + { + inferenceResult[0], // left pitch + inferenceResult[1], // left yaw + inferenceResult[2], // left lid + 0, // left eyebrow (default value) + 0, // left eye wide (default value) + inferenceResult[3], // right pitch + inferenceResult[4], // right yaw + inferenceResult[5], // right lid + 0, // right eyebrow (default value) + 0 // right eye wide (default value) + }; + } + if (Filter != null) { inferenceResult = Filter.Filter(inferenceResult); @@ -68,10 +143,14 @@ private bool ProcessExpressions(ref float[] arKitExpressions) var leftPitch = arKitExpressions[0] * mulY - mulY / 2; var leftYaw = arKitExpressions[1] * mulV - mulV / 2; var leftLid = 1 - arKitExpressions[2]; + var leftEyebrow = arKitExpressions[3]; + var leftEyewide = arKitExpressions[4]; - var rightPitch = arKitExpressions[3] * mulY - mulY / 2; - var rightYaw = arKitExpressions[4] * mulV - mulV / 2; - var rightLid = 1 - arKitExpressions[5]; + var rightPitch = arKitExpressions[5] * mulY - mulY / 2; + var rightYaw = arKitExpressions[6] * mulV - mulV / 2; + var rightLid = 1 - arKitExpressions[7]; + var rightEyebrow = arKitExpressions[8]; + var rightEyewide = arKitExpressions[9]; var eyeY = (leftPitch * leftLid + rightPitch * rightLid) / (leftLid + rightLid); @@ -95,9 +174,13 @@ private bool ProcessExpressions(ref float[] arKitExpressions) convertedExpressions[0] = rightEyeYawCorrected; // left pitch convertedExpressions[1] = eyeY; // left yaw convertedExpressions[2] = rightLid; // left lid - convertedExpressions[3] = leftEyeYawCorrected; // right pitch - convertedExpressions[4] = eyeY; // right yaw - convertedExpressions[5] = leftLid; // right lid + convertedExpressions[3] = leftEyebrow; // left eyebrow + convertedExpressions[4] = leftEyewide; // left eye wide + convertedExpressions[5] = leftEyeYawCorrected; // right pitch + convertedExpressions[6] = eyeY; // right yaw + convertedExpressions[7] = leftLid; // right lid + convertedExpressions[8] = leftEyebrow; // right eyebrow + convertedExpressions[9] = leftEyewide; // right eye wide arKitExpressions = convertedExpressions; diff --git a/src/Baballonia/Services/ParameterSenderService.cs b/src/Baballonia/Services/ParameterSenderService.cs index ff56038f..f438d0a5 100644 --- a/src/Baballonia/Services/ParameterSenderService.cs +++ b/src/Baballonia/Services/ParameterSenderService.cs @@ -32,14 +32,14 @@ public class ParameterSenderService : BackgroundService { "LeftEyeX", "/LeftEyeX" }, { "LeftEyeY", "/LeftEyeY" }, { "LeftEyeLid", "/LeftEyeLid" }, - //{ "LeftEyeWiden", "/LeftEyeWiden" }, - //{ "LeftEyeLower", "/LeftEyeLower" }, + { "LeftEyeLower", "/LeftEyeLower" }, + { "LeftEyeWiden", "/LeftEyeWiden" }, //{ "LeftEyeBrow", "/LeftEyeBrow" }, { "RightEyeX", "/RightEyeX" }, { "RightEyeY", "/RightEyeY" }, { "RightEyeLid", "/RightEyeLid" }, - //{ "RightEyeWiden", "/RightEyeWiden" }, - //{ "RightEyeLower", "/RightEyeLower" }, + { "RightEyeLower", "/RightEyeLower" }, + { "RightEyeWiden", "/RightEyeWiden" }, //{ "RightEyeBrow", "/RightEyeBrow" }, }; diff --git a/src/Baballonia/Utils/Utils.cs b/src/Baballonia/Utils/Utils.cs index 3e88f32f..b26af504 100644 --- a/src/Baballonia/Utils/Utils.cs +++ b/src/Baballonia/Utils/Utils.cs @@ -12,7 +12,7 @@ namespace Baballonia; public static class Utils { - public const int EyeRawExpressions = 6; + public const int EyeRawExpressions = 10; public const int FaceRawExpressions = 45; public const int FramesForEyeInference = 4; @@ -63,6 +63,10 @@ public static class Utils ? Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "ProjectBabble", "ModelData") : AppContext.BaseDirectory; + public static readonly string ModelTrainingDataDirectory = IsSupportedDesktopOS + ? Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "ProjectBabble", "ModelTrainingData") + : AppContext.BaseDirectory; + public static readonly string VrcftLibsDirectory = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "VRCFaceTracking", "CustomLibs"); diff --git a/src/Baballonia/ViewLocator.cs b/src/Baballonia/ViewLocator.cs index cbfaefdc..327bc9d6 100644 --- a/src/Baballonia/ViewLocator.cs +++ b/src/Baballonia/ViewLocator.cs @@ -20,6 +20,7 @@ public ViewLocator() RegisterViewFactory(); RegisterViewFactory(); RegisterViewFactory(); + RegisterViewFactory(); RegisterViewFactory(); RegisterViewFactory(); diff --git a/src/Baballonia/ViewModels/MainViewModel.cs b/src/Baballonia/ViewModels/MainViewModel.cs index 952157ff..121f6d8b 100644 --- a/src/Baballonia/ViewModels/MainViewModel.cs +++ b/src/Baballonia/ViewModels/MainViewModel.cs @@ -41,6 +41,7 @@ private void SetOverlay(bool show) new(typeof(CalibrationViewModel), "EditRegular", Resources.Calibration_Title_Header), // Calibration new(typeof(FirmwareViewModel), "DeveloperBoardRegular", Resources.Firmware_Title_Header), // Firmware new(typeof(VrcViewModel), "CommentRegular", "VRChat"), // VRChat. No translation :P + new(typeof(EyeTrainingViewModel), "HeadsetVrRegular", "Eye training management"), // Eye Training new(typeof(OutputPageViewModel), "TextFirstLineRegular", Resources.Output_Title_Header), // Output new(typeof(AppSettingsViewModel), "SettingsRegular", Resources.Settings_Title_Header), // Settings new(typeof(AboutPageViewModel), "InfoRegular", Resources.About_Title_Header), // About diff --git a/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs b/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs index b7394e49..5e3b505c 100644 --- a/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs +++ b/src/Baballonia/ViewModels/SplitViewPane/CalibrationViewModel.cs @@ -42,12 +42,12 @@ public CalibrationViewModel(EyePipelineManager eyePipelineManager) [ new("LeftEyeLid"), new("RightEyeLid"), - //new ("LeftEyeWiden"), + new ("LeftEyeWiden"), //new ("LeftEyeLower"), - //new ("LeftEyeBrow"), - //new ("RightEyeWiden"), + new ("LeftEyeBrow"), + new ("RightEyeWiden"), //new ("RightEyeLower"), - //new ("RightEyeBrow"), + new ("RightEyeBrow"), ]; JawSettings = @@ -127,14 +127,14 @@ public CalibrationViewModel(EyePipelineManager eyePipelineManager) { "LeftEyeX", 0 }, { "LeftEyeY", 1 }, { "LeftEyeLid", 2 }, - //{ "LeftEyeWiden", }, - //{ "LeftEyeLower", }, + { "LeftEyeLower", 3 }, + { "LeftEyeWiden", 4 }, //{ "LeftEyeBrow", }, - { "RightEyeX", 3 }, - { "RightEyeY", 4 }, - { "RightEyeLid", 5 }, - //{ "RightEyeWiden", }, - //{ "RightEyeLower", }, + { "RightEyeX", 5 }, + { "RightEyeY", 6 }, + { "RightEyeLid", 7 }, + { "RightEyeLower", 8 }, + { "RightEyeWiden", 9 }, //{ "RightEyeBrow", }, }; diff --git a/src/Baballonia/ViewModels/SplitViewPane/EyeTrainingViewModel.cs b/src/Baballonia/ViewModels/SplitViewPane/EyeTrainingViewModel.cs new file mode 100644 index 00000000..2154d964 --- /dev/null +++ b/src/Baballonia/ViewModels/SplitViewPane/EyeTrainingViewModel.cs @@ -0,0 +1,147 @@ +using Avalonia.Controls; +using Avalonia.Media; +using Baballonia.Contracts; +using Baballonia.Helpers; +using Baballonia.Services; +using CommunityToolkit.Mvvm.ComponentModel; +using CommunityToolkit.Mvvm.Input; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.IO; +using System.Linq; +using System.Threading.Tasks; + +namespace Baballonia.ViewModels.SplitViewPane; + +public partial class EyeTrainingViewModel : ViewModelBase +{ + public partial class CalibrationStep : ObservableObject + { + [ObservableProperty] private string _title; + [ObservableProperty] ObservableCollection _files; + [ObservableProperty] string _selectedFile; + + private readonly string _stepName; + private readonly IVROverlay _vrOverlay; + + public CalibrationStep(IVROverlay vrOverlay, string title, string name) + { + _title = title; + _stepName = name; + _files = new ObservableCollection(); + _vrOverlay = vrOverlay; + + GetExistingFiles(); + + if (_files.Count > 0) + _selectedFile = _files[0]; + } + + public void GetExistingFiles() + { + _files.Clear(); + var calibrationFolder = Path.Combine(Utils.ModelTrainingDataDirectory, _stepName); + if (!Directory.Exists(calibrationFolder)) + { + return; + } + + foreach (var file in Directory.GetFiles(calibrationFolder, "*.bin", SearchOption.AllDirectories)) + { + _files.Add(file); + } + } + + [RelayCommand] + private async Task ReRecord() + { + var res = await Task.Run(async () => + { + try + { + return await _vrOverlay.EyeTrackingCalibrationRequested(CalibrationRoutine.Routines.TutorialStep, new List + { + _stepName, + Path.Combine(Utils.ModelTrainingDataDirectory, _stepName, $"{DateTime.Now:yyyyMMdd_HHmmss}.bin") + }); + } + catch (Exception ex) + { + return (false, ex.Message); + } + } + ); + + GetExistingFiles(); + //if (res.Item1) + //{ + // SelectedCalibrationTextBlock.Foreground = new SolidColorBrush(Colors.Green); + //} + //else + //{ + // SelectedCalibrationTextBlock.Foreground = new SolidColorBrush(Colors.Red); + // _logger.LogError(res.Item2); + //} + } + } + + public ObservableCollection CalibrationSteps { get; set; } + + private readonly IVROverlay _vrOverlay; + private readonly ILogger _logger; + private readonly EyePipelineEventBus _eyePipelineEventBus; + + public Button RetrainButton { get; set; } + + public EyeTrainingViewModel( + IVROverlay vrOverlay, + ILogger logger, + EyePipelineEventBus eyePipelineEventBus) + { + _vrOverlay = vrOverlay; + _logger = logger; + _eyePipelineEventBus = eyePipelineEventBus; + + CalibrationSteps = new ObservableCollection + { + new CalibrationStep(_vrOverlay, "Gaze", "gaze"), + new CalibrationStep(_vrOverlay, "Blink", "blink"), + new CalibrationStep(_vrOverlay, "Eyebrows", "brow"), + new CalibrationStep(_vrOverlay, "Squinting", "squint"), + new CalibrationStep(_vrOverlay, "Widening", "widen"), + }; + } + + + [RelayCommand] + private async Task RetrainModel() + { + + var paths = CalibrationSteps.Select(cs => cs.SelectedFile).ToList(); + + var res = await Task.Run(async () => + { + try + { + return await _vrOverlay.EyeTrackingCalibrationRequested(CalibrationRoutine.Routines.TrainModel, paths); + } + catch (Exception ex) + { + return (false, ex.Message); + } + } + ); + if (res.Item1) + { + RetrainButton.Foreground = new SolidColorBrush(Colors.Green); + } + else + { + RetrainButton.Foreground = new SolidColorBrush(Colors.Red); + _logger.LogError(res.Item2); + } + } + +} diff --git a/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs b/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs index 062df520..52154b71 100644 --- a/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs +++ b/src/Baballonia/ViewModels/SplitViewPane/HomePageViewModel.cs @@ -16,6 +16,7 @@ using Microsoft.Extensions.Logging; using OpenCvSharp; using System; +using System.Collections.Generic; using System.Collections.ObjectModel; using System.ComponentModel; using System.Linq; @@ -251,21 +252,21 @@ private void EyeImageUpdateHandler(Mat image) switch (Camera) { case Camera.Left: - { - var leftHalf = new OpenCvSharp.Rect(0, 0, width / 2, height); - var leftRoi = new Mat(image, leftHalf); + { + var leftHalf = new OpenCvSharp.Rect(0, 0, width / 2, height); + var leftRoi = new Mat(image, leftHalf); - UpdateBitmap(leftRoi); - break; - } + UpdateBitmap(leftRoi); + break; + } case Camera.Right: - { - var rightHalf = new OpenCvSharp.Rect(width / 2, 0, width / 2, height); - var rightRoi = new Mat(image, rightHalf); - UpdateBitmap(rightRoi); - break; - } + { + var rightHalf = new OpenCvSharp.Rect(width / 2, 0, width / 2, height); + var rightRoi = new Mat(image, rightHalf); + UpdateBitmap(rightRoi); + break; + } } } else if (channels == 2) @@ -608,7 +609,7 @@ public async Task StartCamera(CameraControllerModel model) await StartCameraWithMaximization(model, startMaximized: true); } - private async Task StartCameraWithMaximization(CameraControllerModel model, bool startMaximized) + private async Task StartCameraWithMaximization(CameraControllerModel model, bool startMaximized) { try { @@ -682,7 +683,7 @@ private async Task RequestVRCalibration() { try { - return await _vrOverlay.EyeTrackingCalibrationRequested(RequestedVRCalibration); + return await _vrOverlay.EyeTrackingCalibrationRequested(RequestedVRCalibration, new List { }); } catch (Exception ex) { diff --git a/src/Baballonia/Views/EyeTrainingView.axaml b/src/Baballonia/Views/EyeTrainingView.axaml new file mode 100644 index 00000000..68554274 --- /dev/null +++ b/src/Baballonia/Views/EyeTrainingView.axaml @@ -0,0 +1,114 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Baballonia/Views/EyeTrainingView.axaml.cs b/src/Baballonia/Views/EyeTrainingView.axaml.cs new file mode 100644 index 00000000..dc929824 --- /dev/null +++ b/src/Baballonia/Views/EyeTrainingView.axaml.cs @@ -0,0 +1,34 @@ +using Avalonia.Controls; +using Avalonia.Interactivity; +using Baballonia.ViewModels.SplitViewPane; + +namespace Baballonia.Views; + +public partial class EyeTrainingView : ViewBase +{ + public EyeTrainingView() + { + InitializeComponent(); + + Loaded += (_, _) => + { + if (DataContext is not EyeTrainingViewModel vm) return; + + vm.RetrainButton = this.Find