diff --git a/src/Baballonia.Desktop/Baballonia.Desktop.csproj b/src/Baballonia.Desktop/Baballonia.Desktop.csproj
index 3f351ec2..8cde76a6 100644
--- a/src/Baballonia.Desktop/Baballonia.Desktop.csproj
+++ b/src/Baballonia.Desktop/Baballonia.Desktop.csproj
@@ -33,7 +33,8 @@
-
+
+
@@ -48,7 +49,7 @@
-
+
diff --git a/src/Baballonia/Services/DefaultInferenceRunner.cs b/src/Baballonia/Services/DefaultInferenceRunner.cs
index 1c478829..6633e4ed 100644
--- a/src/Baballonia/Services/DefaultInferenceRunner.cs
+++ b/src/Baballonia/Services/DefaultInferenceRunner.cs
@@ -113,10 +113,24 @@ private void ConfigurePlatformSpecificGpu(SessionOptions sessionOptions, string
// If the user's system does not support DirectML (for whatever reason,
// it's shipped with Windows 10, version 1903(10.0; Build 18362)+
- // Fallback on good ol' CUDA
+ // Fallback on CUDA (also, if you're on Linux)
+ OrtCUDAProviderOptions cudaProviderOptions = null!;
try
{
- sessionOptions.AppendExecutionProvider_CUDA();
+ cudaProviderOptions = new OrtCUDAProviderOptions();
+ var providerOptionsDict = new Dictionary
+ {
+ ["device_id"] = "0",
+ ["gpu_mem_limit"] = "2147483648",
+ // Overkill options
+ // ["arena_extend_strategy"] = "kSameAsRequested",
+ // ["cudnn_conv_algo_search"] = "DEFAULT",
+ // ["do_copy_in_default_stream"] = "1",
+ // ["cudnn_conv_use_max_workspace"] = "1",
+ // ["cudnn_conv1d_pad_to_nc1d"] = "1"
+ };
+ cudaProviderOptions.UpdateOptions(providerOptionsDict);
+ sessionOptions = SessionOptions.MakeSessionOptionWithCudaProvider(cudaProviderOptions);
_logger.LogInformation("Initialized ExecutionProvider: CUDA for {ModelName}", modelName);
return;
}
@@ -124,20 +138,13 @@ private void ConfigurePlatformSpecificGpu(SessionOptions sessionOptions, string
{
_logger.LogWarning("Failed to create CUDA Execution Provider.");
}
-
- // And, if CUDA fails (or we have an AMD card)
- // Try one more time with MiGraphX/ROCm
- try
+ finally
{
- sessionOptions.AppendExecutionProvider_ROCm();
- _logger.LogInformation("Initialized ExecutionProvider: ROCm for {ModelName}", modelName);
- return;
- }
- catch (Exception)
- {
- _logger.LogWarning("Failed to create ROCm Execution Provider.");
+ cudaProviderOptions.Dispose();
}
+ // And, if CUDA fails (or we have an AMD card)
+ // Try one more time with MiGraphX
try
{
sessionOptions.AppendExecutionProvider_MIGraphX();