-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: report changes in supported protocols to ConnectionHandler
#3651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
f2801fb
9a28cd2
e536dea
7699a1e
450fc1e
a972b9c
f9cf33f
deb30f3
3c1bf5f
9e70729
f1328c5
a4fbcda
3e33108
572ed90
586394b
b329a09
a63af89
72510dd
8ffafdd
ea1a087
27bc507
58039d9
74dd94a
628b519
c7b5011
a02ca55
c347c8a
f3e5e71
b7fa7ef
2bd9d73
e90c40d
84979e4
dbfc7e7
f2d2c88
b41aeb8
bcd872b
021f1d4
fb096ad
e95c738
82642b8
eb66489
8c47bd6
bf9421e
a82343a
19cd9b9
6d3e9ee
df93a4e
c50bcfd
a799798
0260ad1
bf99654
46f4e96
ae7fc93
fe9a6e3
d7651b4
75681c1
bdfb04f
88362e5
3c8d326
77e4e5b
5a772e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,15 +32,15 @@ use libp2p_identity::PeerId; | |
| use libp2p_identity::PublicKey; | ||
| use libp2p_swarm::handler::{ | ||
| ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, | ||
| ProtocolsChange, | ||
| ProtocolSupport, ProtocolsChange, | ||
| }; | ||
| use libp2p_swarm::{ | ||
| ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive, | ||
| NegotiatedSubstream, SubstreamProtocol, | ||
| }; | ||
| use log::warn; | ||
| use smallvec::SmallVec; | ||
| use std::collections::VecDeque; | ||
| use std::collections::{HashSet, VecDeque}; | ||
| use std::{io, pin::Pin, task::Context, task::Poll, time::Duration}; | ||
|
|
||
| /// Protocol handler for sending and receiving identification requests. | ||
|
|
@@ -85,7 +85,8 @@ pub struct Handler { | |
| /// Address observed by or for the remote. | ||
| observed_addr: Multiaddr, | ||
|
|
||
| local_supported_protocols: Vec<String>, | ||
| local_supported_protocols: HashSet<String>, | ||
| remote_supported_protocols: HashSet<String>, | ||
| } | ||
|
|
||
| /// An event from `Behaviour` with the information requested by the `Handler`. | ||
|
|
@@ -138,7 +139,8 @@ impl Handler { | |
| protocol_version, | ||
| agent_version, | ||
| observed_addr, | ||
| local_supported_protocols: vec![], | ||
| local_supported_protocols: HashSet::new(), | ||
| remote_supported_protocols: HashSet::new(), | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -187,10 +189,34 @@ impl Handler { | |
| ) { | ||
| match output { | ||
| future::Either::Left(remote_info) => { | ||
| self.events | ||
| .push(ConnectionHandlerEvent::ReportRemoteProtocols { | ||
| protocols: remote_info.protocols.clone(), | ||
| }); | ||
| let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone()); | ||
|
|
||
| let remote_added_protocols = new_remote_protocols | ||
| .difference(&self.remote_supported_protocols) | ||
| .cloned() | ||
| .collect::<HashSet<_>>(); | ||
| let remote_removed_protocols = self | ||
| .remote_supported_protocols | ||
| .difference(&new_remote_protocols) | ||
| .cloned() | ||
| .collect::<HashSet<_>>(); | ||
|
|
||
| if !remote_added_protocols.is_empty() { | ||
| self.events | ||
| .push(ConnectionHandlerEvent::ReportRemoteProtocols( | ||
| ProtocolSupport::Added(remote_added_protocols), | ||
| )); | ||
| } | ||
|
|
||
| if !remote_removed_protocols.is_empty() { | ||
| self.events | ||
| .push(ConnectionHandlerEvent::ReportRemoteProtocols( | ||
| ProtocolSupport::Removed(remote_removed_protocols), | ||
| )); | ||
| } | ||
|
|
||
| self.remote_supported_protocols = new_remote_protocols; | ||
|
|
||
| self.events | ||
| .push(ConnectionHandlerEvent::Custom(Event::Identified( | ||
| remote_info, | ||
|
|
@@ -251,7 +277,7 @@ impl ConnectionHandler for Handler { | |
| protocol_version: self.protocol_version.clone(), | ||
| agent_version: self.agent_version.clone(), | ||
| listen_addrs, | ||
| protocols: self.local_supported_protocols.clone(), | ||
| protocols: Vec::from_iter(self.local_supported_protocols.clone()), | ||
| observed_addr: self.observed_addr.clone(), | ||
| }; | ||
|
|
||
|
|
@@ -311,10 +337,11 @@ impl ConnectionHandler for Handler { | |
| self.inbound_identify_push.take(); | ||
|
|
||
| if let Ok(info) = res { | ||
| self.events | ||
| .push(ConnectionHandlerEvent::ReportRemoteProtocols { | ||
| protocols: info.protocols.clone(), | ||
| }); | ||
| // TODO: report new protocols | ||
| // self.events | ||
| // .push(ConnectionHandlerEvent::ReportRemoteProtocols { | ||
| // protocols: info.protocols.clone(), | ||
| // }); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Open TODO item. |
||
| return Poll::Ready(ConnectionHandlerEvent::Custom(Event::Identified(info))); | ||
| } | ||
| } | ||
|
|
@@ -353,8 +380,13 @@ impl ConnectionHandler for Handler { | |
| self.on_dial_upgrade_error(dial_upgrade_error) | ||
| } | ||
| ConnectionEvent::AddressChange(_) | ConnectionEvent::ListenUpgradeError(_) => {} | ||
| ConnectionEvent::LocalProtocolsChange(ProtocolsChange { protocols }) => { | ||
| self.local_supported_protocols = protocols.to_vec(); | ||
| ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => { | ||
| self.local_supported_protocols.extend(added.cloned()); | ||
| } | ||
| ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => { | ||
| for p in removed { | ||
| self.local_supported_protocols.remove(p); | ||
| } | ||
| } | ||
| ConnectionEvent::RemoteProtocolsChange(_) => {} | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,7 +29,8 @@ pub use error::{ | |
|
|
||
| use crate::handler::{ | ||
| AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound, | ||
| FullyNegotiatedOutbound, ListenUpgradeError, ProtocolsChange, | ||
| FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsAdded, ProtocolsChange, | ||
| ProtocolsRemoved, | ||
| }; | ||
| use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend}; | ||
| use crate::{ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive, SubstreamProtocol}; | ||
|
|
@@ -45,6 +46,7 @@ use libp2p_core::upgrade::{InboundUpgradeApply, OutboundUpgradeApply}; | |
| use libp2p_core::Endpoint; | ||
| use libp2p_core::{upgrade, ProtocolName as _, UpgradeError}; | ||
| use libp2p_identity::PeerId; | ||
| use std::collections::HashSet; | ||
| use std::future::Future; | ||
| use std::sync::atomic::{AtomicUsize, Ordering}; | ||
| use std::task::Waker; | ||
|
|
@@ -146,7 +148,7 @@ where | |
| SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>, | ||
| >, | ||
|
|
||
| supported_protocols: Vec<String>, | ||
| supported_protocols: HashSet<String>, | ||
| } | ||
|
|
||
| impl<THandler> fmt::Debug for Connection<THandler> | ||
|
|
@@ -184,7 +186,7 @@ where | |
| substream_upgrade_protocol_override, | ||
| max_negotiating_inbound_streams, | ||
| requested_substreams: Default::default(), | ||
| supported_protocols: vec![], | ||
| supported_protocols: HashSet::new(), | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -217,6 +219,21 @@ where | |
| supported_protocols, | ||
| } = self.get_mut(); | ||
|
|
||
| let protocol = handler.listen_protocol(); | ||
|
|
||
| let new_protocols = protocol | ||
| .upgrade() | ||
| .protocol_info() | ||
| .filter_map(|i| String::from_utf8(i.protocol_name().to_vec()).ok()) | ||
| .collect::<HashSet<_>>(); | ||
|
|
||
| handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( | ||
| ProtocolsChange::Added(ProtocolsAdded { | ||
| protocols: new_protocols.difference(&HashSet::new()).peekable(), | ||
| }), | ||
| )); | ||
| *supported_protocols = new_protocols; | ||
|
thomaseizinger marked this conversation as resolved.
Outdated
|
||
|
|
||
| loop { | ||
| match requested_substreams.poll_next_unpin(cx) { | ||
| Poll::Ready(Some(Ok(()))) => continue, | ||
|
|
@@ -248,11 +265,23 @@ where | |
| Poll::Ready(ConnectionHandlerEvent::Close(err)) => { | ||
| return Poll::Ready(Err(ConnectionError::Handler(err))); | ||
| } | ||
| Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols { protocols }) => { | ||
| Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols( | ||
| ProtocolSupport::Added(protocols), | ||
| )) => { | ||
| handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange( | ||
| ProtocolsChange { | ||
| protocols: &protocols, | ||
| }, | ||
| ProtocolsChange::Added(ProtocolsAdded { | ||
| protocols: protocols.difference(&HashSet::new()).peekable(), // This is a bit of a hack to use the same type internally in `ProtocolsAdded`. | ||
| }), | ||
| )); | ||
| continue; | ||
| } | ||
|
Comment on lines
+263
to
+281
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess we could cache the remote protocols and then only pass a Just an idea. Happy to go with at it is here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is a good idea, it would likely remove a lot of unnecessary invocations of the various handlers.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've pushed a first version, will likely iterate on that further. |
||
| Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols( | ||
| ProtocolSupport::Removed(protocols), | ||
| )) => { | ||
| handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange( | ||
| ProtocolsChange::Removed(ProtocolsRemoved { | ||
| protocols: protocols.difference(&HashSet::new()).peekable(), // This is a bit of a hack to use the same type internally in `ProtocolsRemoved`. | ||
| }), | ||
| )); | ||
| continue; | ||
| } | ||
|
thomaseizinger marked this conversation as resolved.
|
||
|
|
@@ -367,21 +396,33 @@ where | |
| Poll::Ready(substream) => { | ||
| let protocol = handler.listen_protocol(); | ||
|
|
||
| let mut new_protocols = protocol | ||
| let new_protocols = protocol | ||
| .upgrade() | ||
| .protocol_info() | ||
| .filter_map(|i| String::from_utf8(i.protocol_name().to_vec()).ok()) | ||
| .collect::<Vec<_>>(); | ||
|
|
||
| new_protocols.sort(); | ||
|
|
||
| if supported_protocols != &new_protocols { | ||
| handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( | ||
| ProtocolsChange { | ||
| protocols: &new_protocols, | ||
| }, | ||
| )); | ||
| *supported_protocols = new_protocols; | ||
| .collect::<HashSet<_>>(); | ||
|
|
||
| if &new_protocols != supported_protocols { | ||
| let mut added_protocols = | ||
| new_protocols.difference(supported_protocols).peekable(); | ||
| let mut removed_protocols = | ||
| supported_protocols.difference(&new_protocols).peekable(); | ||
|
|
||
| if added_protocols.peek().is_some() { | ||
| handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( | ||
| ProtocolsChange::Removed(ProtocolsRemoved { | ||
| protocols: added_protocols, | ||
| }), | ||
| )); | ||
| } | ||
|
|
||
| if removed_protocols.peek().is_some() { | ||
| handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( | ||
| ProtocolsChange::Removed(ProtocolsRemoved { | ||
| protocols: removed_protocols, | ||
| }), | ||
| )); | ||
| } | ||
| } | ||
|
|
||
| negotiating_in.push(SubstreamUpgrade::new_inbound(substream, protocol)); | ||
|
|
@@ -956,8 +997,16 @@ mod tests { | |
| Self::OutboundOpenInfo, | ||
| >, | ||
| ) { | ||
| if let ConnectionEvent::LocalProtocolsChange(ProtocolsChange { protocols }) = event { | ||
| self.reported_protocols = protocols.to_vec(); | ||
| match event { | ||
| ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => { | ||
| self.reported_protocols.extend(added.cloned()); | ||
| } | ||
| ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => { | ||
| for protocol in removed { | ||
| self.reported_protocols.retain(|p| p != protocol); | ||
| } | ||
| } | ||
| _ => {} | ||
| } | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.