# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
# See LICENSE for details

"""
This module tests twisted.conch.ssh.connection.
"""

import struct

from twisted.conch.ssh import channel
from twisted.conch.test import test_userauth
from twisted.python.reflect import requireModule
from twisted.trial import unittest

cryptography = requireModule("cryptography")

from twisted.conch import error

if cryptography:
    from twisted.conch.ssh import common, connection
else:

    class connection:  # type: ignore[no-redef]
        class SSHConnection:
            pass


class TestChannel(channel.SSHChannel):
    """
    A mocked-up version of twisted.conch.ssh.channel.SSHChannel.

    @ivar gotOpen: True if channelOpen has been called.
    @type gotOpen: L{bool}
    @ivar specificData: the specific channel open data passed to channelOpen.
    @type specificData: L{bytes}
    @ivar openFailureReason: the reason passed to openFailed.
    @type openFailureReason: an exception instance, e.g. C{error.ConchError}
        when the other side refuses the open, or
        C{twisted.internet.error.ConnectionLost} when the connection stops
        before the open completes.
    @ivar inBuffer: a C{list} of byte strings received by the channel.
    @type inBuffer: C{list}
    @ivar extBuffer: a C{list} of 2-tuples (type, extended data) received by
        the channel.
    @type extBuffer: C{list}
    @ivar numberRequests: the number of requests that have been made to this
        channel.
    @type numberRequests: L{int}
    @ivar gotEOF: True if the other side sent EOF.
    @type gotEOF: L{bool}
    @ivar gotOneClose: True if the other side sent a channel close.
    @type gotOneClose: L{bool}
    @ivar gotClosed: True if the channel is closed.
    @type gotClosed: L{bool}
    """

    name = b"TestChannel"
    gotOpen = False
    gotClosed = False

    def logPrefix(self):
        return "TestChannel %i" % self.id

    def channelOpen(self, specificData):
        """
        The channel is open.  Set up the instance variables.
        """
        self.gotOpen = True
        self.specificData = specificData
        self.inBuffer = []
        self.extBuffer = []
        self.numberRequests = 0
        self.gotEOF = False
        self.gotOneClose = False
        self.gotClosed = False

    def openFailed(self, reason):
        """
        Opening the channel failed.  Store the reason why.
        """
        self.openFailureReason = reason

    def request_test(self, data):
        """
        A test request.  Return True if data is 'data'.

        @type data: L{bytes}
        """
        self.numberRequests += 1
        return data == b"data"

    def dataReceived(self, data):
        """
        Data was received.  Store it in the buffer.
        """
        self.inBuffer.append(data)

    def extReceived(self, code, data):
        """
        Extended data was received.  Store it in the buffer.
        """
        self.extBuffer.append((code, data))

    def eofReceived(self):
        """
        EOF was received.  Remember it.
        """
        self.gotEOF = True

    def closeReceived(self):
        """
        Close was received.  Remember it.
        """
        self.gotOneClose = True

    def closed(self):
        """
        The channel is closed.  Remember it.
        """
        self.gotClosed = True


class TestAvatar:
    """
    A mocked-up version of twisted.conch.avatar.ConchUser
    """

    # Error code carried by the ConchError raised for b"conch-error-args";
    # tests may override it per instance.
    _ARGS_ERROR_CODE = 123

    def lookupChannel(self, channelType, windowSize, maxPacket, data):
        """
        The server wants us to return a channel.  If the requested channel is
        our TestChannel, return it.  If it is C{b"conch-error-args"}, raise a
        ConchError whose arguments are in the wrong order.  Otherwise return
        None.
        """
        if channelType == TestChannel.name:
            return TestChannel(
                remoteWindow=windowSize,
                remoteMaxPacket=maxPacket,
                data=data,
                avatar=self,
            )
        elif channelType == b"conch-error-args":
            # Raise a ConchError with backwards arguments to make sure the
            # connection fixes it for us.  This case should be deprecated and
            # deleted eventually, but only after all of Conch gets the argument
            # order right.
            raise error.ConchError(self._ARGS_ERROR_CODE, "error args in wrong order")

    def gotGlobalRequest(self, requestType, data):
        """
        The client has made a global request.  If the global request is
        'TestGlobal', return True.  If the global request is 'TestData',
        return True and the request-specific data we received.  Otherwise,
        return False.
        """
        if requestType == b"TestGlobal":
            return True
        elif requestType == b"TestData":
            return True, data
        else:
            return False


class TestConnection(connection.SSHConnection):
    """
    A subclass of SSHConnection for testing.

    @ivar channel: the current channel.
    @type channel: C{TestChannel}
    """

    if not cryptography:
        skip = "Cannot run without cryptography"

    def logPrefix(self):
        return "TestConnection"

    def global_TestGlobal(self, data):
        """
        The other side made the 'TestGlobal' global request.  Return True.
        """
        return True

    def global_Test_Data(self, data):
        """
        The other side made the 'Test-Data' global request.  Return True and
        the data we received.
        """
        return True, data

    def channel_TestChannel(self, windowSize, maxPacket, data):
        """
        The other side is requesting the TestChannel.  Create a C{TestChannel}
        instance, store it, and return it.
        """
        self.channel = TestChannel(
            remoteWindow=windowSize, remoteMaxPacket=maxPacket, data=data
        )
        return self.channel

    def channel_ErrorChannel(self, windowSize, maxPacket, data):
        """
        The other side is requesting the ErrorChannel.  Raise an exception.
        """
        raise AssertionError("no such thing")


class ConnectionTests(unittest.TestCase):
    if not cryptography:
        skip = "Cannot run without cryptography"

    def setUp(self):
        self.transport = test_userauth.FakeTransport(None)
        self.transport.avatar = TestAvatar()
        self.conn = TestConnection()
        self.conn.transport = self.transport
        self.conn.serviceStarted()

    def _openChannel(self, channel):
        """
        Open the channel with the default connection.
        """
        self.conn.openChannel(channel)
        # Drop the MSG_CHANNEL_OPEN packet so tests only see what follows.
        self.transport.packets = self.transport.packets[:-1]
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(
            struct.pack(">2L", channel.id, 255) + b"\x00\x02\x00\x00\x00\x00\x80\x00"
        )

    def tearDown(self):
        self.conn.serviceStopped()

    def test_linkAvatar(self):
        """
        Test that the connection links itself to the avatar in the
        transport.
        """
        self.assertIs(self.transport.avatar.conn, self.conn)

    def test_serviceStopped(self):
        """
        Test that serviceStopped() closes any open channels and fails any
        channels that are still waiting to be opened.
        """
        channel1 = TestChannel()
        channel2 = TestChannel()
        self.conn.openChannel(channel1)
        self.conn.openChannel(channel2)
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b"\x00\x00\x00\x00" * 4)
        self.assertTrue(channel1.gotOpen)
        self.assertFalse(channel1.gotClosed)
        self.assertFalse(channel2.gotOpen)
        self.assertFalse(channel2.gotClosed)
        self.conn.serviceStopped()
        self.assertTrue(channel1.gotClosed)
        self.assertFalse(channel2.gotOpen)
        self.assertFalse(channel2.gotClosed)
        from twisted.internet.error import ConnectionLost

        self.assertIsInstance(channel2.openFailureReason, ConnectionLost)

    def test_GLOBAL_REQUEST(self):
        """
        Test that global request packets are dispatched to the global_*
        methods and the return values are translated into success or failure
        messages.
        """
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestGlobal") + b"\xff")
        self.assertEqual(
            self.transport.packets, [(connection.MSG_REQUEST_SUCCESS, b"")]
        )
        self.transport.packets = []
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestData") + b"\xff" + b"test data")
        self.assertEqual(
            self.transport.packets, [(connection.MSG_REQUEST_SUCCESS, b"test data")]
        )
        self.transport.packets = []
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestBad") + b"\xff")
        self.assertEqual(
            self.transport.packets, [(connection.MSG_REQUEST_FAILURE, b"")]
        )
        self.transport.packets = []
        # wantReply is false, so nothing should be sent back.
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestGlobal") + b"\x00")
        self.assertEqual(self.transport.packets, [])

    def test_REQUEST_SUCCESS(self):
        """
        Test that global request success packets cause the Deferred to be
        called back.
        """
        d = self.conn.sendGlobalRequest(b"request", b"data", True)
        self.conn.ssh_REQUEST_SUCCESS(b"data")

        def check(data):
            self.assertEqual(data, b"data")

        d.addCallback(check)
        d.addErrback(self.fail)
        return d

    def test_REQUEST_FAILURE(self):
        """
        Test that global request failure packets cause the Deferred to be
        erred back.
        """
        d = self.conn.sendGlobalRequest(b"request", b"data", True)
        self.conn.ssh_REQUEST_FAILURE(b"data")

        def check(f):
            self.assertEqual(f.value.data, b"data")

        d.addCallback(self.fail)
        d.addErrback(check)
        return d

    def test_CHANNEL_OPEN(self):
        """
        Test that open channel packets cause a channel to be created and
        opened or a failure message to be returned.
        """
        # Without an avatar the connection dispatches to channel_* methods.
        del self.transport.avatar
        self.conn.ssh_CHANNEL_OPEN(common.NS(b"TestChannel") + b"\x00\x00\x00\x01" * 4)
        self.assertTrue(self.conn.channel.gotOpen)
        self.assertEqual(self.conn.channel.conn, self.conn)
        self.assertEqual(self.conn.channel.data, b"\x00\x00\x00\x01")
        self.assertEqual(self.conn.channel.specificData, b"\x00\x00\x00\x01")
        self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
        self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_CONFIRMATION,
                    b"\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00"
                    b"\x00\x00\x80\x00",
                )
            ],
        )
        self.transport.packets = []
        self.conn.ssh_CHANNEL_OPEN(common.NS(b"BadChannel") + b"\x00\x00\x00\x02" * 4)
        self.flushLoggedErrors()
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_FAILURE,
                    b"\x00\x00\x00\x02\x00\x00\x00\x03"
                    + common.NS(b"unknown channel")
                    + common.NS(b""),
                )
            ],
        )
        self.transport.packets = []
        self.conn.ssh_CHANNEL_OPEN(common.NS(b"ErrorChannel") + b"\x00\x00\x00\x02" * 4)
        self.flushLoggedErrors()
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_FAILURE,
                    b"\x00\x00\x00\x02\x00\x00\x00\x02"
                    + common.NS(b"unknown failure")
                    + common.NS(b""),
                )
            ],
        )

    def _lookupChannelErrorTest(self, code):
        """
        Deliver a request for a channel open which will result in an exception
        being raised during channel lookup.  Assert that an error response is
        delivered as a result.

        @param code: the error code the avatar's ConchError should carry; it
            must be echoed back in the MSG_CHANNEL_OPEN_FAILURE packet.
        @type code: L{int}
        """
        self.transport.avatar._ARGS_ERROR_CODE = code
        self.conn.ssh_CHANNEL_OPEN(
            common.NS(b"conch-error-args") + b"\x00\x00\x00\x01" * 4
        )
        errors = self.flushLoggedErrors(error.ConchError)
        self.assertEqual(len(errors), 1, f"Expected one error, got: {errors!r}")
        self.assertEqual(errors[0].value.args, (code, "error args in wrong order"))
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_FAILURE,
                    # The response starts with the requester's channel number
                    # (1, from the open request above), followed by the error
                    # code and the error message.
                    struct.pack(">2L", 1, code)
                    + common.NS(b"error args in wrong order")
                    + common.NS(b""),
                )
            ],
        )

    def test_lookupChannelError(self):
        """
        If a C{lookupChannel} implementation raises L{error.ConchError} with the
        arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
        sent in response to the message.

        This is a temporary work-around until L{error.ConchError} is given
        better attributes and all of the Conch code starts constructing
        instances of it properly.  Eventually this functionality should be
        deprecated and then removed.
        """
        self._lookupChannelErrorTest(123)

    def test_CHANNEL_OPEN_CONFIRMATION(self):
        """
        Test that channel open confirmation packets cause the channel to be
        notified that it's open.
        """
        channel = TestChannel()
        self.conn.openChannel(channel)
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b"\x00\x00\x00\x00" * 5)
        self.assertEqual(channel.remoteWindowLeft, 0)
        self.assertEqual(channel.remoteMaxPacket, 0)
        self.assertEqual(channel.specificData, b"\x00\x00\x00\x00")
        self.assertEqual(self.conn.channelsToRemoteChannel[channel], 0)
        self.assertEqual(self.conn.localToRemoteChannel[0], 0)

    def test_CHANNEL_OPEN_FAILURE(self):
        """
        Test that channel open failure packets cause the channel to be
        notified that its opening failed.
        """
        channel = TestChannel()
        self.conn.openChannel(channel)
        self.conn.ssh_CHANNEL_OPEN_FAILURE(
            b"\x00\x00\x00\x00\x00\x00\x00" b"\x01" + common.NS(b"failure!")
        )
        self.assertEqual(channel.openFailureReason.args, (b"failure!", 1))
        self.assertIsNone(self.conn.channels.get(channel))

    def test_CHANNEL_WINDOW_ADJUST(self):
        """
        Test that channel window adjust messages add bytes to the channel
        window.
        """
        channel = TestChannel()
        self._openChannel(channel)
        oldWindowSize = channel.remoteWindowLeft
        self.conn.ssh_CHANNEL_WINDOW_ADJUST(b"\x00\x00\x00\x00\x00\x00\x00" b"\x01")
        self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)

    def test_CHANNEL_DATA(self):
        """
        Test that channel data messages are passed up to the channel, or
        cause the channel to be closed if the data is too large.
        """
        channel = TestChannel(localWindow=6, localMaxPacket=5)
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x00" + common.NS(b"data"))
        self.assertEqual(channel.inBuffer, [b"data"])
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_WINDOW_ADJUST,
                    b"\x00\x00\x00\xff" b"\x00\x00\x00\x04",
                )
            ],
        )
        self.transport.packets = []
        # One byte more than the local window allows.
        longData = b"a" * (channel.localWindowLeft + 1)
        self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x00" + common.NS(longData))
        self.assertEqual(channel.inBuffer, [b"data"])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )
        channel = TestChannel()
        self._openChannel(channel)
        # One byte more than the maximum packet size allows.
        bigData = b"a" * (channel.localMaxPacket + 1)
        self.transport.packets = []
        self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x01" + common.NS(bigData))
        self.assertEqual(channel.inBuffer, [])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )

    def test_CHANNEL_EXTENDED_DATA(self):
        """
        Test that channel extended data messages are passed up to the channel,
        or cause the channel to be closed if they're too big.
        """
        channel = TestChannel(localWindow=6, localMaxPacket=5)
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_EXTENDED_DATA(
            b"\x00\x00\x00\x00\x00\x00\x00" b"\x00" + common.NS(b"data")
        )
        self.assertEqual(channel.extBuffer, [(0, b"data")])
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_WINDOW_ADJUST,
                    b"\x00\x00\x00\xff" b"\x00\x00\x00\x04",
                )
            ],
        )
        self.transport.packets = []
        # One byte more than the local window allows.
        longData = b"a" * (channel.localWindowLeft + 1)
        self.conn.ssh_CHANNEL_EXTENDED_DATA(
            b"\x00\x00\x00\x00\x00\x00\x00" b"\x00" + common.NS(longData)
        )
        self.assertEqual(channel.extBuffer, [(0, b"data")])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )
        channel = TestChannel()
        self._openChannel(channel)
        # One byte more than the maximum packet size allows.
        bigData = b"a" * (channel.localMaxPacket + 1)
        self.transport.packets = []
        self.conn.ssh_CHANNEL_EXTENDED_DATA(
            b"\x00\x00\x00\x01\x00\x00\x00" b"\x00" + common.NS(bigData)
        )
        self.assertEqual(channel.extBuffer, [])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )

    def test_CHANNEL_EOF(self):
        """
        Test that channel eof messages are passed up to the channel.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_EOF(b"\x00\x00\x00\x00")
        self.assertTrue(channel.gotEOF)

    def test_CHANNEL_CLOSE(self):
        """
        Test that channel close messages are passed up to the channel.  Also,
        test that channel.close() is called if both sides are closed when this
        message is received.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.assertTrue(channel.gotOpen)
        self.assertFalse(channel.gotOneClose)
        self.assertFalse(channel.gotClosed)
        self.conn.sendClose(channel)
        self.conn.ssh_CHANNEL_CLOSE(b"\x00\x00\x00\x00")
        self.assertTrue(channel.gotOneClose)
        self.assertTrue(channel.gotClosed)

    def test_CHANNEL_REQUEST_success(self):
        """
        Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_REQUEST(
            b"\x00\x00\x00\x00" + common.NS(b"test") + b"\x00"
        )
        self.assertEqual(channel.numberRequests, 1)
        d = self.conn.ssh_CHANNEL_REQUEST(
            b"\x00\x00\x00\x00" + common.NS(b"test") + b"\xff" + b"data"
        )

        def check(result):
            self.assertEqual(
                self.transport.packets,
                [(connection.MSG_CHANNEL_SUCCESS, b"\x00\x00\x00\xff")],
            )

        d.addCallback(check)
        return d

    def test_CHANNEL_REQUEST_failure(self):
        """
        Test that channel requests that fail send MSG_CHANNEL_FAILURE.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.ssh_CHANNEL_REQUEST(
            b"\x00\x00\x00\x00" + common.NS(b"test") + b"\xff"
        )

        def check(result):
            self.assertEqual(
                self.transport.packets,
                [(connection.MSG_CHANNEL_FAILURE, b"\x00\x00\x00\xff")],
            )

        d.addCallback(self.fail)
        d.addErrback(check)
        return d

    def test_CHANNEL_REQUEST_SUCCESS(self):
        """
        Test that channel request success messages cause the Deferred to be
        called back.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(channel, b"test", b"data", True)
        self.conn.ssh_CHANNEL_SUCCESS(b"\x00\x00\x00\x00")
        # The success message must have fired the pending Deferred
        # synchronously and without a failure.
        self.successResultOf(d)

    def test_CHANNEL_REQUEST_FAILURE(self):
        """
        Test that channel request failure messages cause the Deferred to be
        erred back.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(channel, b"test", b"", True)
        self.conn.ssh_CHANNEL_FAILURE(b"\x00\x00\x00\x00")

        def check(result):
            self.assertEqual(result.value.value, "channel request failed")

        d.addCallback(self.fail)
        d.addErrback(check)
        return d

    def test_sendGlobalRequest(self):
        """
        Test that global request messages are sent in the right format.
        """
        d = self.conn.sendGlobalRequest(b"wantReply", b"data", True)
        # must be added to prevent errbacking during teardown
        d.addErrback(lambda failure: None)
        self.conn.sendGlobalRequest(b"noReply", b"", False)
        self.assertEqual(
            self.transport.packets,
            [
                (connection.MSG_GLOBAL_REQUEST, common.NS(b"wantReply") + b"\xffdata"),
                (connection.MSG_GLOBAL_REQUEST, common.NS(b"noReply") + b"\x00"),
            ],
        )
        self.assertEqual(self.conn.deferreds, {"global": [d]})

    def test_openChannel(self):
        """
        Test that open channel messages are sent in the right format.
        """
        channel = TestChannel()
        self.conn.openChannel(channel, b"aaaa")
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN,
                    common.NS(b"TestChannel")
                    + b"\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa",
                )
            ],
        )
        self.assertEqual(channel.id, 0)
        self.assertEqual(self.conn.localChannelID, 1)

    def test_sendRequest(self):
        """
        Test that channel request messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(channel, b"test", b"test", True)
        # needed to prevent errbacks during teardown.
        d.addErrback(lambda failure: None)
        self.conn.sendRequest(channel, b"test2", b"", False)
        channel.localClosed = True  # emulate sending a close message
        self.conn.sendRequest(channel, b"test3", b"", True)
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_REQUEST,
                    b"\x00\x00\x00\xff" + common.NS(b"test") + b"\x01test",
                ),
                (
                    connection.MSG_CHANNEL_REQUEST,
                    b"\x00\x00\x00\xff" + common.NS(b"test2") + b"\x00",
                ),
            ],
        )
        self.assertEqual(self.conn.deferreds[0], [d])

    def test_adjustWindow(self):
        """
        Test that channel window adjust messages cause bytes to be added
        to the window.
        """
        channel = TestChannel(localWindow=5)
        self._openChannel(channel)
        channel.localWindowLeft = 0
        self.conn.adjustWindow(channel, 1)
        self.assertEqual(channel.localWindowLeft, 1)
        channel.localClosed = True
        self.conn.adjustWindow(channel, 2)
        self.assertEqual(channel.localWindowLeft, 1)
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_WINDOW_ADJUST,
                    b"\x00\x00\x00\xff" b"\x00\x00\x00\x01",
                )
            ],
        )

    def test_sendData(self):
        """
        Test that channel data messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendData(channel, b"a")
        channel.localClosed = True
        self.conn.sendData(channel, b"b")
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_DATA, b"\x00\x00\x00\xff" + common.NS(b"a"))],
        )

    def test_sendExtendedData(self):
        """
        Test that channel extended data messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendExtendedData(channel, 1, b"test")
        channel.localClosed = True
        self.conn.sendExtendedData(channel, 2, b"test2")
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_EXTENDED_DATA,
                    b"\x00\x00\x00\xff" + b"\x00\x00\x00\x01" + common.NS(b"test"),
                )
            ],
        )

    def test_sendEOF(self):
        """
        Test that channel EOF messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendEOF(channel)
        self.assertEqual(
            self.transport.packets, [(connection.MSG_CHANNEL_EOF, b"\x00\x00\x00\xff")]
        )
        channel.localClosed = True
        self.conn.sendEOF(channel)
        self.assertEqual(
            self.transport.packets, [(connection.MSG_CHANNEL_EOF, b"\x00\x00\x00\xff")]
        )

    def test_sendClose(self):
        """
        Test that channel close messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendClose(channel)
        self.assertTrue(channel.localClosed)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )
        self.conn.sendClose(channel)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )

        channel2 = TestChannel()
        self._openChannel(channel2)
        self.assertTrue(channel2.gotOpen)
        self.assertFalse(channel2.gotClosed)
        channel2.remoteClosed = True
        self.conn.sendClose(channel2)
        self.assertTrue(channel2.gotClosed)

    def test_getChannelWithAvatar(self):
        """
        Test that getChannel dispatches to the avatar when an avatar is
        present. Correct functioning without the avatar is verified in
        test_CHANNEL_OPEN.
        """
        channel = self.conn.getChannel(b"TestChannel", 50, 30, b"data")
        self.assertEqual(channel.data, b"data")
        self.assertEqual(channel.remoteWindowLeft, 50)
        self.assertEqual(channel.remoteMaxPacket, 30)
        self.assertRaises(
            error.ConchError, self.conn.getChannel, b"BadChannel", 50, 30, b"data"
        )

    def test_gotGlobalRequestWithoutAvatar(self):
        """
        Test that gotGlobalRequests dispatches to global_* without an avatar.
        """
        del self.transport.avatar
        self.assertTrue(self.conn.gotGlobalRequest(b"TestGlobal", b"data"))
        self.assertEqual(
            self.conn.gotGlobalRequest(b"Test-Data", b"data"), (True, b"data")
        )
        self.assertFalse(self.conn.gotGlobalRequest(b"BadGlobal", b"data"))

    def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
        """
        Whenever an SSH channel gets closed any Deferred that was returned by a
        sendRequest() on its parent connection must be errbacked.
        """
        channel = TestChannel()
        self._openChannel(channel)

        d = self.conn.sendRequest(channel, b"dummyrequest", b"dummydata", wantReply=1)
        d = self.assertFailure(d, error.ConchError)
        self.conn.channelClosed(channel)
        return d


class CleanConnectionShutdownTests(unittest.TestCase):
    """
    Check whether correct cleanup is performed on connection shutdown.
    """

    if not cryptography:
        skip = "Cannot run without cryptography"

    def setUp(self):
        self.transport = test_userauth.FakeTransport(None)
        self.transport.avatar = TestAvatar()
        self.conn = TestConnection()
        self.conn.transport = self.transport

    def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
        """
        Once the service is stopped any leftover global deferred returned by
        a sendGlobalRequest() call must be errbacked.
        """
        self.conn.serviceStarted()

        d = self.conn.sendGlobalRequest(b"dummyrequest", b"dummydata", wantReply=1)
        d = self.assertFailure(d, error.ConchError)
        self.conn.serviceStopped()
        return d