Skip to content
Merged
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Added

Changed
~~~~~~~

* Fix ssh zombies when using ProxyCommand from ssh config #4881
* Install pack with the latest tag version if it exists when branch is not specialized.
(improvement) #4743
* Implement "continue" engine command to orquesta workflow. (improvement) #4740
Expand Down
32 changes: 32 additions & 0 deletions st2actions/tests/unit/test_paramiko_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,35 @@ def test_use_ssh_config_port_value_provided_in_the_config(self, mock_sshclient):

call_kwargs = mock_client.connect.call_args[1]
self.assertEqual(call_kwargs['port'], 9999)

@patch('paramiko.SSHClient', Mock)
@patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
MagicMock(return_value=False))

def test_socket_closed(self):
conn_params = {'hostname': 'dummy.host.org',
'username': 'ubuntu',
'password': 'pass',
'timeout': '600'}
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()

# Make sure .close() doesn't actually call anything real
ssh_client.client = Mock()
ssh_client.sftp_client = None
ssh_client.bastion_client = None

ssh_client.socket = Mock()

# Make sure we havent called any close methods at this point
self.assertEqual(ssh_client.socket.process.kill.call_count, 0)
self.assertEqual(ssh_client.socket.process.poll.call_count, 0)

# Call the function that has changed
ssh_client.close()

# Make sure we have called kill and poll
self.assertEqual(ssh_client.socket.process.kill.call_count, 1)
self.assertEqual(ssh_client.socket.process.poll.call_count, 1)


11 changes: 9 additions & 2 deletions st2common/st2common/runners/paramiko_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None

self.bastion_client = None
self.bastion_socket = None
self.socket = None

def connect(self):
"""
Expand Down Expand Up @@ -455,6 +456,12 @@ def close(self):

self.client.close()

if self.socket:
self.logger.debug('Closing proxycommand socket connection')
# https://github.com/paramiko/paramiko/issues/789 Avoid zombie ssh processes
self.socket.process.kill()
self.socket.process.poll()

if self.sftp_client:
self.sftp_client.close()

Expand Down Expand Up @@ -698,8 +705,8 @@ def _connect(self, host, socket=None):
'_username': self.username, '_timeout': self.timeout}
self.logger.debug('Connecting to server', extra=extra)

socket = socket or ssh_config_file_info.get('sock', None)
if socket:
self.socket = socket or ssh_config_file_info.get('sock', None)
if self.socket:
conninfo['sock'] = socket

client = paramiko.SSHClient()
Expand Down