diff --git a/scripts/performance/common.py b/scripts/performance/common.py index 0be3b93fc2a..5d96d043773 100644 --- a/scripts/performance/common.py +++ b/scripts/performance/common.py @@ -210,24 +210,16 @@ def retry_on_exception( def get_certificates() -> list[str]: ''' - Gets the certificates from the certhelper tool and on Mac uses find-certificate. + Gets the certificates from the certhelper tool. ''' - if ismac(): - certs: list[str] = [] - with open("/Users/helix-runner/certs/LabCert1.pfx", "rb") as f: - certs.append(base64.b64encode(f.read()).decode()) - with open("/Users/helix-runner/certs/LabCert2.pfx", "rb") as f: - certs.append(base64.b64encode(f.read()).decode()) - return certs - else: - cmd_line = [(os.path.join(str(helixpayload()), 'certhelper', "CertHelper%s" % extension()))] - cert_helper = RunCommand(cmd_line, None, True, False, 0) - try: - return cert_helper.run_and_get_stdout().splitlines() - except Exception as ex: - getLogger().error("Failed to get certificates") - getLogger().error('{0}: {1}'.format(type(ex), str(ex))) - return [] + cmd_line = [(os.path.join(str(helixpayload()), 'certhelper', "CertHelper%s" % extension()))] + cert_helper = RunCommand(cmd_line, None, True, False, 0) + try: + return cert_helper.run_and_get_stdout().splitlines() + except Exception as ex: + getLogger().error("Failed to get certificates") + getLogger().error('{0}: {1}'.format(type(ex), str(ex))) + return [] def __write_pipeline_variable(name: str, value: str): diff --git a/src/tools/CertHelper/KeyVaultCert.cs b/src/tools/CertHelper/KeyVaultCert.cs index e6fd7d73630..7b791af1206 100644 --- a/src/tools/CertHelper/KeyVaultCert.cs +++ b/src/tools/CertHelper/KeyVaultCert.cs @@ -8,6 +8,7 @@ using System.IO; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; @@ -21,6 +22,7 @@ public class KeyVaultCert private readonly string _clientId = "8c4b65ef-5a73-4d5a-a298-962d4a4ef7bc"; public X509Certificate2Collection KeyVaultCertificates { get; set; } + public List KeyVaultCertificateBytes { get; set; } public ILocalCert LocalCerts { get; set; } private TokenCredential _credential { get; set; } private CertificateClient _certClient { get; set; } @@ -33,16 +35,28 @@ public KeyVaultCert(TokenCredential? cred = null, CertificateClient? certClient _certClient = certClient ?? new CertificateClient(new Uri(_keyVaultUrl), _credential); _secretClient = secretClient ?? new SecretClient(new Uri(_keyVaultUrl), _credential); KeyVaultCertificates = new X509Certificate2Collection(); + KeyVaultCertificateBytes = new List(); } - public async Task LoadKeyVaultCertsAsync() + public async Task LoadKeyVaultCertsAsync(bool? rawBytesOnly = null) { - KeyVaultCertificates.Add(await FindCertificateInKeyVaultAsync(Constants.Cert1Name)); - KeyVaultCertificates.Add(await FindCertificateInKeyVaultAsync(Constants.Cert2Name)); + bool skipX509Load = rawBytesOnly ?? RuntimeInformation.IsOSPlatform(OSPlatform.OSX); - if (KeyVaultCertificates.Where(c => c == null).Count() > 0) + var (cert1, bytes1) = await FindCertificateInKeyVaultAsync(Constants.Cert1Name, skipX509Load); + var (cert2, bytes2) = await FindCertificateInKeyVaultAsync(Constants.Cert2Name, skipX509Load); + + KeyVaultCertificateBytes.Add(bytes1); + KeyVaultCertificateBytes.Add(bytes2); + + if (!skipX509Load) { - throw new Exception("One or more certificates not found"); + KeyVaultCertificates.Add(cert1!); + KeyVaultCertificates.Add(cert2!); + + if (KeyVaultCertificates.Where(c => c == null).Count() > 0) + { + throw new Exception("One or more certificates not found"); + } } } @@ -136,7 +150,7 @@ private async Task GetCertificateCredentialAsync(st return ccc; } - private async Task FindCertificateInKeyVaultAsync(string certName) + private async Task<(X509Certificate2?, byte[])> FindCertificateInKeyVaultAsync(string certName, bool rawBytesOnly = false) { var keyVaultCert = await _certClient.GetCertificateAsync(certName); if(keyVaultCert.Value == null) @@ -149,12 +163,18 @@ private async Task FindCertificateInKeyVaultAsync(string certN throw new Exception("Certificate secret not found in Key Vault"); } var certBytes = Convert.FromBase64String(secret.Value.Value); + + if (rawBytesOnly) + { + return (null, certBytes); + } + #if NET9_0_OR_GREATER var cert = X509CertificateLoader.LoadPkcs12(certBytes, "", X509KeyStorageFlags.Exportable | X509KeyStorageFlags.PersistKeySet); #else var cert = new X509Certificate2(certBytes, "", X509KeyStorageFlags.Exportable | X509KeyStorageFlags.PersistKeySet); #endif - return cert; + return (cert, certBytes); } public bool ShouldRotateCerts() diff --git a/src/tools/CertHelper/LocalCert.cs b/src/tools/CertHelper/LocalCert.cs index 5e0dbe80b35..45cbba9794f 100644 --- a/src/tools/CertHelper/LocalCert.cs +++ b/src/tools/CertHelper/LocalCert.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; @@ -13,19 +14,29 @@ public class LocalCert : ILocalCert { public X509Certificate2Collection Certificates { get; set; } public bool RequiresBootstrap { get; private set; } - internal IX509Store LocalMachineCerts { get; set; } + internal IX509Store? LocalMachineCerts { get; set; } public LocalCert(IX509Store? store = null) { - LocalMachineCerts = store ?? new TestableX509Store(); Certificates = new X509Certificate2Collection(); RequiresBootstrap = false; - GetLocalCerts(); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + // Skip Keychain access on macOS to avoid password prompts. + // Certs are managed as files on disk instead. + RequiresBootstrap = true; + } + else + { + LocalMachineCerts = store ?? new TestableX509Store(); + GetLocalCerts(); + } } private void GetLocalCerts() { - foreach (var cert in LocalMachineCerts.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false)) + foreach (var cert in LocalMachineCerts!.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false)) { if (cert.Subject == "CN=dotnetperf.microsoft.com") { diff --git a/src/tools/CertHelper/Program.cs b/src/tools/CertHelper/Program.cs index 2e717a7a119..c0828601540 100644 --- a/src/tools/CertHelper/Program.cs +++ b/src/tools/CertHelper/Program.cs @@ -21,31 +21,42 @@ static async Task Main(string[] args) await kvc.LoadKeyVaultCertsAsync(); if (kvc.ShouldRotateCerts()) { - using (var localMachineCerts = new X509Store(StoreName.My, StoreLocation.CurrentUser)) + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - localMachineCerts.Open(OpenFlags.ReadWrite); - localMachineCerts.RemoveRange(kvc.LocalCerts.Certificates); - localMachineCerts.AddRange(kvc.KeyVaultCertificates); + WriteCertsToDisk(kvc.KeyVaultCertificateBytes); + } + else + { + using (var localMachineCerts = new X509Store(StoreName.My, StoreLocation.CurrentUser)) + { + localMachineCerts.Open(OpenFlags.ReadWrite); + localMachineCerts.RemoveRange(kvc.LocalCerts.Certificates); + localMachineCerts.AddRange(kvc.KeyVaultCertificates); + } } } - var bcc = new BlobContainerClient(new Uri("https://pvscmdupload.blob.core.windows.net/certstatus"), - new ClientCertificateCredential(TENANT_ID, CERT_CLIENT_ID, kvc.KeyVaultCertificates.First(), new() {SendCertificateChain = true})); - var currentKeyValutCertThumbprints = ""; - foreach (var cert in kvc.KeyVaultCertificates) - { - currentKeyValutCertThumbprints += $"[{DateTimeOffset.UtcNow}] {cert.Thumbprint}{Environment.NewLine}"; - } - var blob = bcc.GetBlobClient(System.Environment.MachineName); - if (blob.Exists()) - { - var result = blob.DownloadContent(); - var currentBlob = result.Value.Content.ToString(); - currentBlob = currentBlob + currentKeyValutCertThumbprints; - blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentBlob)), overwrite: true); - } - else + + if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentKeyValutCertThumbprints)), overwrite: false); + var bcc = new BlobContainerClient(new Uri("https://pvscmdupload.blob.core.windows.net/certstatus"), + new ClientCertificateCredential(TENANT_ID, CERT_CLIENT_ID, kvc.KeyVaultCertificates.First(), new() {SendCertificateChain = true})); + var currentKeyValutCertThumbprints = ""; + foreach (var cert in kvc.KeyVaultCertificates) + { + currentKeyValutCertThumbprints += $"[{DateTimeOffset.UtcNow}] {cert.Thumbprint}{Environment.NewLine}"; + } + var blob = bcc.GetBlobClient(System.Environment.MachineName); + if (blob.Exists()) + { + var result = blob.DownloadContent(); + var currentBlob = result.Value.Content.ToString(); + currentBlob = currentBlob + currentKeyValutCertThumbprints; + blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentBlob)), overwrite: true); + } + else + { + blob.Upload(new MemoryStream(Encoding.UTF8.GetBytes(currentKeyValutCertThumbprints)), overwrite: false); + } } } catch (Exception ex) @@ -55,13 +66,57 @@ static async Task Main(string[] args) Console.Error.WriteLine(ex.StackTrace); } - using (var store = new X509Store(StoreName.My, StoreLocation.CurrentUser, OpenFlags.ReadWrite)) + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { - foreach(var cert in store.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false)) + ReadCertsFromDisk(); + } + else + { + using (var store = new X509Store(StoreName.My, StoreLocation.CurrentUser, OpenFlags.ReadWrite)) { - Console.WriteLine(Convert.ToBase64String(cert.Export(X509ContentType.Pfx))); + foreach(var cert in store.Certificates.Find(X509FindType.FindBySubjectName, "dotnetperf.microsoft.com", false)) + { + Console.WriteLine(Convert.ToBase64String(cert.Export(X509ContentType.Pfx))); + } } } return 0; } + + static string GetMacCertDirectory() + { + var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + return Path.Combine(home, "certs"); + } + + static void WriteCertsToDisk(List certBytes) + { + var certDir = GetMacCertDirectory(); + Directory.CreateDirectory(certDir); + + var certNames = new[] { Constants.Cert1Name, Constants.Cert2Name }; + for (int i = 0; i < certBytes.Count && i < certNames.Length; i++) + { + var pfxPath = Path.Combine(certDir, $"{certNames[i]}.pfx"); + File.WriteAllBytes(pfxPath, certBytes[i]); + Console.Error.WriteLine($"Wrote certificate to {pfxPath}"); + } + } + + static void ReadCertsFromDisk() + { + var certDir = GetMacCertDirectory(); + foreach (var certName in new[] { Constants.Cert1Name, Constants.Cert2Name }) + { + var pfxPath = Path.Combine(certDir, $"{certName}.pfx"); + if (File.Exists(pfxPath)) + { + Console.WriteLine(Convert.ToBase64String(File.ReadAllBytes(pfxPath))); + } + else + { + Console.Error.WriteLine($"Certificate file not found: {pfxPath}"); + } + } + } } diff --git a/src/tools/CertHelperTests/KeyVaultCertTests.cs b/src/tools/CertHelperTests/KeyVaultCertTests.cs index e7ee26df040..8a129f7e624 100644 --- a/src/tools/CertHelperTests/KeyVaultCertTests.cs +++ b/src/tools/CertHelperTests/KeyVaultCertTests.cs @@ -212,5 +212,96 @@ public async Task ShouldRotateCerts_ShouldReturnTrue_WhenNoLocalCertsExist() // Assert Assert.True(result); } + + [Fact] + public async Task LoadKeyVaultCertsAsync_RawBytesOnly_ShouldPopulateBytesButNotCertificates() + { + // Arrange + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act + await keyVaultCert.LoadKeyVaultCertsAsync(rawBytesOnly: true); + + // Assert + Assert.Equal(2, keyVaultCert.KeyVaultCertificateBytes.Count); + Assert.True(keyVaultCert.KeyVaultCertificateBytes[0].Length > 0); + Assert.True(keyVaultCert.KeyVaultCertificateBytes[1].Length > 0); + Assert.Empty(keyVaultCert.KeyVaultCertificates); + } + + [Fact] + public async Task LoadKeyVaultCertsAsync_RawBytesOnly_BytesShouldBeValidPfx() + { + // Arrange + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act + await keyVaultCert.LoadKeyVaultCertsAsync(rawBytesOnly: true); + + // Assert - bytes should be loadable as PFX + foreach (var certBytes in keyVaultCert.KeyVaultCertificateBytes) + { + var cert = X509CertificateLoader.LoadPkcs12(certBytes, "", X509KeyStorageFlags.DefaultKeySet); + Assert.NotNull(cert); + Assert.False(string.IsNullOrEmpty(cert.Thumbprint)); + } + } + + [Fact] + public async Task LoadKeyVaultCertsAsync_Default_ShouldPopulateBothBytesAndCertificates() + { + // Arrange + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act + await keyVaultCert.LoadKeyVaultCertsAsync(rawBytesOnly: false); + + // Assert + Assert.Equal(2, keyVaultCert.KeyVaultCertificateBytes.Count); + Assert.Equal(2, keyVaultCert.KeyVaultCertificates.Count); + } + + [Fact] + public async Task ShouldRotateCerts_ShouldReturnTrue_WhenBootstrapRequired() + { + // Arrange - simulates macOS scenario where LocalCert skips Keychain + Mock mockTokenCred; + Mock mockCertClient; + Mock mockSecretClient; + Mock mockLocalCert; + CertStoreSetup(out mockTokenCred, out mockCertClient, out mockSecretClient, out mockLocalCert); + + mockLocalCert.Setup(lc => lc.Certificates).Returns(new X509Certificate2Collection()); + mockLocalCert.Setup(lc => lc.RequiresBootstrap).Returns(true); + + var keyVaultCert = new KeyVaultCert(mockTokenCred.Object, mockCertClient.Object, mockSecretClient.Object, mockLocalCert.Object); + + // Act + await keyVaultCert.LoadKeyVaultCertsAsync(rawBytesOnly: true); + var result = keyVaultCert.ShouldRotateCerts(); + + // Assert + Assert.True(result); + Assert.Equal(2, keyVaultCert.KeyVaultCertificateBytes.Count); + Assert.Empty(keyVaultCert.KeyVaultCertificates); + } }