Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ final class SafeScriptMessageHandler: NSObject, WKScriptMessageHandler {
_ userContentController: WKUserContentController,
didReceive message: WKScriptMessage
) {
// Only the top-level document on an allowed server host may talk to the native bridge.
// Only the top-level document on an allowed server origin may talk to the native bridge.
guard shouldAllowMessage(
isMainFrame: message.frameInfo.isMainFrame,
host: message.frameInfo.securityOrigin.host
scheme: message.frameInfo.securityOrigin.protocol,
host: message.frameInfo.securityOrigin.host,
port: message.frameInfo.securityOrigin.port
) else {
return
}
Expand All @@ -28,17 +30,30 @@ final class SafeScriptMessageHandler: NSObject, WKScriptMessageHandler {
)
}

func shouldAllowMessage(isMainFrame: Bool, host: String) -> Bool {
isMainFrame && allowedHosts.contains(host)
func shouldAllowMessage(isMainFrame: Bool, scheme: String, host: String, port: Int) -> Bool {
isMainFrame && allowedOrigins.contains(originKey(scheme: scheme, host: host, port: port))
}

private var allowedHosts: Set<String> {
private var allowedOrigins: Set<String> {
let urls = [
server.info.connection.address(for: .internal),
server.info.connection.address(for: .external),
server.info.connection.address(for: .remoteUI),
]

return Set(urls.compactMap { $0?.host })
return Set(urls.compactMap(originKey(url:)))
}

private func originKey(url: URL?) -> String? {
guard let url, let scheme = url.scheme?.lowercased(), let host = url.host,
let port = url.portWithFallback else {
return nil
}

return originKey(scheme: scheme, host: host, port: port)
}

private func originKey(scheme: String, host: String, port: Int) -> String {
"\(scheme.lowercased())://\(host.lowercased()):\(port)"
}
}
2 changes: 1 addition & 1 deletion Sources/Shared/Common/Extensions/URL+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public extension URL {
}

// port will be removed if 80 or 443 by WKWebView, so we provide defaults for comparison
internal var portWithFallback: Int? {
var portWithFallback: Int? {
if let port {
return port
}
Expand Down
26 changes: 19 additions & 7 deletions Tests/App/WebView/SafeScriptMessageHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,33 @@ import Testing
import WebKit

struct SafeScriptMessageHandlerTests {
@Test func allowsMainFrameMessageFromConfiguredServerHost() {
@Test func allowsMainFrameMessageFromConfiguredServerOrigin() {
ServerFixture.reset()
let handler = SafeScriptMessageHandler(
server: ServerFixture.withRemoteConnection,
delegate: NoOpScriptMessageHandler()
)

#expect(handler.shouldAllowMessage(isMainFrame: true, host: "external.example.com"))
#expect(handler.shouldAllowMessage(isMainFrame: true, host: "internal.example.com"))
#expect(handler.shouldAllowMessage(isMainFrame: true, host: "ui.nabu.casa"))
#expect(handler.shouldAllowMessage(isMainFrame: true, scheme: "https", host: "external.example.com", port: 443))
#expect(handler.shouldAllowMessage(isMainFrame: true, scheme: "http", host: "internal.example.com", port: 80))
#expect(handler.shouldAllowMessage(isMainFrame: true, scheme: "https", host: "ui.nabu.casa", port: 443))
}

@Test func rejectsMessageFromHostOutsideConfiguredServerHosts() {
@Test func rejectsMessageFromOriginOutsideConfiguredServerOrigins() {
ServerFixture.reset()
let handler = SafeScriptMessageHandler(
server: ServerFixture.withRemoteConnection,
delegate: NoOpScriptMessageHandler()
)

#expect(!handler.shouldAllowMessage(isMainFrame: true, host: "evil.example.com"))
#expect(!handler.shouldAllowMessage(isMainFrame: true, scheme: "https", host: "evil.example.com", port: 443))
#expect(!handler.shouldAllowMessage(
isMainFrame: true,
scheme: "https",
host: "external.example.com",
port: 8123
))
#expect(!handler.shouldAllowMessage(isMainFrame: true, scheme: "http", host: "external.example.com", port: 443))
}

@Test func rejectsIframeMessageEvenWhenHostIsAllowed() {
Expand All @@ -33,7 +40,12 @@ struct SafeScriptMessageHandlerTests {
delegate: NoOpScriptMessageHandler()
)

#expect(!handler.shouldAllowMessage(isMainFrame: false, host: "external.example.com"))
#expect(!handler.shouldAllowMessage(
isMainFrame: false,
scheme: "https",
host: "external.example.com",
port: 443
))
}
}

Expand Down
Loading