Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ final class SafeScriptMessageHandler: NSObject, WKScriptMessageHandler {
_ userContentController: WKUserContentController,
didReceive message: WKScriptMessage
) {
// Only the top-level document on an allowed server host may talk to the native bridge.
// Only the top-level document on an allowed server origin may talk to the native bridge.
guard shouldAllowMessage(
isMainFrame: message.frameInfo.isMainFrame,
host: message.frameInfo.securityOrigin.host
host: message.frameInfo.securityOrigin.host,
port: message.frameInfo.securityOrigin.port
) else {
return
}
Expand All @@ -28,17 +29,44 @@ final class SafeScriptMessageHandler: NSObject, WKScriptMessageHandler {
)
}

func shouldAllowMessage(isMainFrame: Bool, host: String) -> Bool {
isMainFrame && allowedHosts.contains(host)
func shouldAllowMessage(isMainFrame: Bool, host: String, port: Int) -> Bool {
isMainFrame && allowedOrigins.contains(originKey(host: host, port: port))
}

private var allowedHosts: Set<String> {
private var allowedOrigins: Set<String> {
let urls = [
server.info.connection.address(for: .internal),
server.info.connection.address(for: .external),
server.info.connection.address(for: .remoteUI),
]

return Set(urls.compactMap { $0?.host })
return Set(urls.compactMap(originKey(url:)))
}

private func originKey(url: URL?) -> String? {
guard let url, let host = url.host, let port = normalizedPort(for: url) else {
return nil
}

return originKey(host: host, port: port)
}

private func originKey(host: String, port: Int) -> String {
"\(host):\(port)"
}

private func normalizedPort(for url: URL) -> Int? {
if let port = url.port {
return port
}

switch url.scheme?.lowercased() {
case "http":
return 80
case "https":
return 443
default:
return nil
}
}
}
13 changes: 7 additions & 6 deletions Tests/App/WebView/SafeScriptMessageHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@ struct SafeScriptMessageHandlerTests {
delegate: NoOpScriptMessageHandler()
)

#expect(handler.shouldAllowMessage(isMainFrame: true, host: "external.example.com"))
#expect(handler.shouldAllowMessage(isMainFrame: true, host: "internal.example.com"))
#expect(handler.shouldAllowMessage(isMainFrame: true, host: "ui.nabu.casa"))
#expect(handler.shouldAllowMessage(isMainFrame: true, host: "external.example.com", port: 443))
#expect(handler.shouldAllowMessage(isMainFrame: true, host: "internal.example.com", port: 80))
#expect(handler.shouldAllowMessage(isMainFrame: true, host: "ui.nabu.casa", port: 443))
}

@Test func rejectsMessageFromHostOutsideConfiguredServerHosts() {
@Test func rejectsMessageFromOriginOutsideConfiguredServerOrigins() {
ServerFixture.reset()
let handler = SafeScriptMessageHandler(
server: ServerFixture.withRemoteConnection,
delegate: NoOpScriptMessageHandler()
)

#expect(!handler.shouldAllowMessage(isMainFrame: true, host: "evil.example.com"))
#expect(!handler.shouldAllowMessage(isMainFrame: true, host: "evil.example.com", port: 443))
#expect(!handler.shouldAllowMessage(isMainFrame: true, host: "external.example.com", port: 8123))
}

@Test func rejectsIframeMessageEvenWhenHostIsAllowed() {
Expand All @@ -33,7 +34,7 @@ struct SafeScriptMessageHandlerTests {
delegate: NoOpScriptMessageHandler()
)

#expect(!handler.shouldAllowMessage(isMainFrame: false, host: "external.example.com"))
#expect(!handler.shouldAllowMessage(isMainFrame: false, host: "external.example.com", port: 443))
}
}

Expand Down
Loading