diff --git a/pika/connection.py b/pika/connection.py index 76f581eb2dc15dbf3059d932978c54fd29df4f37_cGlrYS9jb25uZWN0aW9uLnB5..0e70d49d33edc79c6f3cc7d1f2a0853248f92065_cGlrYS9jb25uZWN0aW9uLnB5 100644 --- a/pika/connection.py +++ b/pika/connection.py @@ -6,6 +6,7 @@ import logging import math import numbers +import os import platform import warnings import ssl @@ -944,8 +945,67 @@ self.socket_timeout = socket_timeout def _set_url_ssl_options(self, value): - """Deserialize and apply the corresponding query string arg""" - self.ssl_options = ast.literal_eval(value) + """Deserialize and apply the corresponding query string arg + + """ + opts = ast.literal_eval(value) + if opts is None: + if self.ssl_options is not None: + raise ValueError( + 'Specified ssl_options=None URL arg is inconsistent with ' + 'the specified https URL scheme.') + else: + # Note: this is the deprecated wrap_socket signature and info: + # + # Internally, function creates a SSLContext with protocol + # ssl_version and SSLContext.options set to cert_reqs. + # If parameters keyfile, certfile, ca_certs or ciphers are set, + # then the values are passed to SSLContext.load_cert_chain(), + # SSLContext.load_verify_locations(), and SSLContext.set_ciphers(). + # + # ssl.wrap_socket(sock, + # keyfile=None, + # certfile=None, + # server_side=False, # Not URL-supported + # cert_reqs=CERT_NONE, # Not URL-supported + # ssl_version=PROTOCOL_TLS, # Not URL-supported + # ca_certs=None, + # do_handshake_on_connect=True, # Not URL-supported + # suppress_ragged_eofs=True, # Not URL-supported + # ciphers=None + cxt = None + if 'ca_certs' in opts: + opt_ca_certs = opts['ca_certs'] + if os.path.isfile(opt_ca_certs): + cxt = ssl.create_default_context(cafile=opt_ca_certs) + elif os.path.isdir(opt_ca_certs): + cxt = ssl.create_default_context(capath=opt_ca_certs) + else: + LOGGER.warning('ca_certs is specified via ssl_options but ' + 'is neither a valid file nor directory: "%s"', + opt_ca_certs) + + if 'certfile' in opts: + if os.path.isfile(opts['certfile']): + keyfile = opts.get('keyfile') + password = opts.get('password') + cxt.load_cert_chain(opts['certfile'], keyfile, password) + else: + LOGGER.warning('certfile is specified via ssl_options but ' + 'is not a valid file: "%s"', + opts['certfile']) + + if 'ciphers' in opts: + opt_ciphers = opts['ciphers'] + if opt_ciphers is not None: + cxt.set_ciphers(opt_ciphers) + else: + LOGGER.warning('ciphers specified in ssl_options but ' + 'evaluates to None') + + server_hostname = opts.get('server_hostname') + self.ssl_options = pika.SSLOptions(context=cxt, + server_hostname=server_hostname) def _set_url_tcp_options(self, value): """Deserialize and apply the corresponding query string arg""" diff --git a/tests/unit/connection_parameters_tests.py b/tests/unit/connection_parameters_tests.py index 76f581eb2dc15dbf3059d932978c54fd29df4f37_dGVzdHMvdW5pdC9jb25uZWN0aW9uX3BhcmFtZXRlcnNfdGVzdHMucHk=..0e70d49d33edc79c6f3cc7d1f2a0853248f92065_dGVzdHMvdW5pdC9jb25uZWN0aW9uX3BhcmFtZXRlcnNfdGVzdHMucHk= 100644 --- a/tests/unit/connection_parameters_tests.py +++ b/tests/unit/connection_parameters_tests.py @@ -606,7 +606,12 @@ 'retry_delay': 3, 'socket_timeout': 100.5, 'ssl_options': { - 'ssl': 'options' + 'ca_certs': '/etc/ssl', + 'certfile': '/etc/certs/cert.pem', + 'keyfile': '/etc/certs/key.pem', + 'password': 'test123', + 'ciphers': None, + 'server_hostname': 'blah.blah.com' }, 'tcp_options': { 'TCP_USER_TIMEOUT': 1000, @@ -619,7 +624,7 @@ test_params['backpressure_detection'] = backpressure virtual_host = '/' query_string = urlencode(test_params) - test_url = ('https://myuser:mypass@www.test.com:5678/%s?%s' % ( + test_url = ('amqps://myuser:mypass@www.test.com:5678/%s?%s' % ( url_quote(virtual_host, safe=''), query_string, )) @@ -632,13 +637,17 @@ expected_value = query_args[t_param] actual_value = getattr(params, t_param) - self.assertEqual( - actual_value, - expected_value, - msg='Expected %s=%r, but got %r' % - (t_param, expected_value, actual_value)) + if t_param == 'ssl_options': + self.assertEqual(actual_value.server_hostname, + expected_value['server_hostname']) + else: + self.assertEqual( + actual_value, + expected_value, + msg='Expected %s=%r, but got %r' % + (t_param, expected_value, actual_value)) self.assertEqual(params.backpressure_detection, backpressure == 't') # check all values from base URL @@ -640,9 +649,11 @@ self.assertEqual(params.backpressure_detection, backpressure == 't') # check all values from base URL - self.assertEqual(params.ssl, True) + self.assertIsNotNone(params.ssl_options) + self.assertIsNotNone(params.ssl_options.context) + self.assertIsInstance(params.ssl_options.context, ssl.SSLContext) self.assertEqual(params.credentials.username, 'myuser') self.assertEqual(params.credentials.password, 'mypass') self.assertEqual(params.host, 'www.test.com')