diff --git a/src/Baballonia.Desktop/Baballonia.Desktop.csproj b/src/Baballonia.Desktop/Baballonia.Desktop.csproj index 3f351ec2..8cde76a6 100644 --- a/src/Baballonia.Desktop/Baballonia.Desktop.csproj +++ b/src/Baballonia.Desktop/Baballonia.Desktop.csproj @@ -33,7 +33,8 @@ - + + @@ -48,7 +49,7 @@ - + diff --git a/src/Baballonia/Services/DefaultInferenceRunner.cs b/src/Baballonia/Services/DefaultInferenceRunner.cs index 1c478829..6633e4ed 100644 --- a/src/Baballonia/Services/DefaultInferenceRunner.cs +++ b/src/Baballonia/Services/DefaultInferenceRunner.cs @@ -113,10 +113,24 @@ private void ConfigurePlatformSpecificGpu(SessionOptions sessionOptions, string // If the user's system does not support DirectML (for whatever reason, // it's shipped with Windows 10, version 1903(10.0; Build 18362)+ - // Fallback on good ol' CUDA + // Fallback on CUDA (also, if you're on Linux) + OrtCUDAProviderOptions cudaProviderOptions = null!; try { - sessionOptions.AppendExecutionProvider_CUDA(); + cudaProviderOptions = new OrtCUDAProviderOptions(); + var providerOptionsDict = new Dictionary + { + ["device_id"] = "0", + ["gpu_mem_limit"] = "2147483648", + // Overkill options + // ["arena_extend_strategy"] = "kSameAsRequested", + // ["cudnn_conv_algo_search"] = "DEFAULT", + // ["do_copy_in_default_stream"] = "1", + // ["cudnn_conv_use_max_workspace"] = "1", + // ["cudnn_conv1d_pad_to_nc1d"] = "1" + }; + cudaProviderOptions.UpdateOptions(providerOptionsDict); + sessionOptions = SessionOptions.MakeSessionOptionWithCudaProvider(cudaProviderOptions); _logger.LogInformation("Initialized ExecutionProvider: CUDA for {ModelName}", modelName); return; } @@ -124,20 +138,13 @@ private void ConfigurePlatformSpecificGpu(SessionOptions sessionOptions, string { _logger.LogWarning("Failed to create CUDA Execution Provider."); } - - // And, if CUDA fails (or we have an AMD card) - // Try one more time with MiGraphX/ROCm - try + finally { - sessionOptions.AppendExecutionProvider_ROCm(); - _logger.LogInformation("Initialized ExecutionProvider: ROCm for {ModelName}", modelName); - return; - } - catch (Exception) - { - _logger.LogWarning("Failed to create ROCm Execution Provider."); + cudaProviderOptions.Dispose(); } + // And, if CUDA fails (or we have an AMD card) + // Try one more time with MiGraphX try { sessionOptions.AppendExecutionProvider_MIGraphX();