Bug 710345: Upgrade pywebsocket to v606 (RFC 6455). r=mcmanus

This commit is contained in:
Jason Duell 2011-12-20 00:20:00 -08:00
parent ac1d404f8b
commit 2d4bdcf04b
14 changed files with 464 additions and 175 deletions

View File

@ -65,6 +65,12 @@ Installation:
PythonOption mod_pywebsocket.allow_draft75 On PythonOption mod_pywebsocket.allow_draft75 On
If you want to allow handlers whose canonical path is not under the root
directory (i.e. symbolic link is in root directory but its target is not),
configure as follows:
PythonOption mod_pywebsocket.allow_handlers_outside_root_dir On
Example snippet of httpd.conf: Example snippet of httpd.conf:
(mod_pywebsocket is in /websock_lib, WebSocket handlers are in (mod_pywebsocket is in /websock_lib, WebSocket handlers are in
/websock_handlers, port is 80 for ws, 443 for wss.) /websock_handlers, port is 80 for ws, 443 for wss.)

View File

@ -111,14 +111,9 @@ class StreamBase(object):
bytes = self._request.connection.read(length) bytes = self._request.connection.read(length)
if not bytes: if not bytes:
# MOZILLA: Patrick McManus found we needed this for Python 2.5 to raise ConnectionTerminatedException(
# work. Not sure which tests he meant: I found that 'Receiving %d byte failed. Peer (%r) closed connection' %
# content/base/test/test_websocket* all worked fine with 2.5 with (length, (self._request.connection.remote_addr,)))
# the original Google code. JDuell
#raise ConnectionTerminatedException(
# 'Receiving %d byte failed. Peer (%r) closed connection' %
# (length, (self._request.connection.remote_addr,)))
raise ConnectionTerminatedException('connection terminated: read failed')
return bytes return bytes
def _write(self, bytes): def _write(self, bytes):

View File

@ -298,9 +298,11 @@ class Stream(StreamBase):
'Mask bit on the received frame did\'nt match masking ' 'Mask bit on the received frame did\'nt match masking '
'configuration for received frames') 'configuration for received frames')
# The spec doesn't disallow putting a value in 0x0-0xFFFF into the # The Hybi-13 and later specs disallow putting a value in 0x0-0xFFFF
# 8-octet extended payload length field (or 0x0-0xFD in 2-octet field). # into the 8-octet extended payload length field (or 0x0-0xFD in
# So, we don't check the range of extended_payload_length. # 2-octet field).
valid_length_encoding = True
length_encoding_bytes = 1
if payload_length == 127: if payload_length == 127:
extended_payload_length = self.receive_bytes(8) extended_payload_length = self.receive_bytes(8)
payload_length = struct.unpack( payload_length = struct.unpack(
@ -308,10 +310,23 @@ class Stream(StreamBase):
if payload_length > 0x7FFFFFFFFFFFFFFF: if payload_length > 0x7FFFFFFFFFFFFFFF:
raise InvalidFrameException( raise InvalidFrameException(
'Extended payload length >= 2^63') 'Extended payload length >= 2^63')
if self._request.ws_version >= 13 and payload_length < 0x10000:
valid_length_encoding = False
length_encoding_bytes = 8
elif payload_length == 126: elif payload_length == 126:
extended_payload_length = self.receive_bytes(2) extended_payload_length = self.receive_bytes(2)
payload_length = struct.unpack( payload_length = struct.unpack(
'!H', extended_payload_length)[0] '!H', extended_payload_length)[0]
if self._request.ws_version >= 13 and payload_length < 126:
valid_length_encoding = False
length_encoding_bytes = 2
if not valid_length_encoding:
self._logger.warning(
'Payload length is not encoded using the minimal number of '
'bytes (%d is encoded using %d bytes)',
payload_length,
length_encoding_bytes)
if mask == 1: if mask == 1:
masking_nonce = self.receive_bytes(4) masking_nonce = self.receive_bytes(4)

View File

@ -41,9 +41,16 @@ VERSION_HYBI07 = 7
VERSION_HYBI08 = 8 VERSION_HYBI08 = 8
VERSION_HYBI09 = 8 VERSION_HYBI09 = 8
VERSION_HYBI10 = 8 VERSION_HYBI10 = 8
VERSION_HYBI11 = 8
VERSION_HYBI12 = 8
VERSION_HYBI13 = 13
VERSION_HYBI14 = 13
VERSION_HYBI15 = 13
VERSION_HYBI16 = 13
VERSION_HYBI17 = 13
# Constants indicating WebSocket protocol latest version. # Constants indicating WebSocket protocol latest version.
VERSION_HYBI_LATEST = VERSION_HYBI10 VERSION_HYBI_LATEST = VERSION_HYBI13
# Port numbers # Port numbers
DEFAULT_WEB_SOCKET_PORT = 80 DEFAULT_WEB_SOCKET_PORT = 80
@ -95,10 +102,17 @@ STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001 STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002 STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED = 1003 STATUS_UNSUPPORTED = 1003
STATUS_TOO_LARGE = 1004
STATUS_CODE_NOT_AVAILABLE = 1005 STATUS_CODE_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSE = 1006 STATUS_ABNORMAL_CLOSE = 1006
STATUS_INVALID_UTF8 = 1007 STATUS_INVALID_FRAME_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_MANDATORY_EXT = 1010
# HTTP status codes
HTTP_STATUS_BAD_REQUEST = 400
HTTP_STATUS_FORBIDDEN = 403
HTTP_STATUS_NOT_FOUND = 404
def is_control_opcode(opcode): def is_control_opcode(opcode):

View File

@ -54,13 +54,14 @@ _PASSIVE_CLOSING_HANDSHAKE_HANDLER_NAME = (
class DispatchException(Exception): class DispatchException(Exception):
"""Exception in dispatching WebSocket request.""" """Exception in dispatching WebSocket request."""
def __init__(self, name, status=404): def __init__(self, name, status=common.HTTP_STATUS_NOT_FOUND):
super(DispatchException, self).__init__(name) super(DispatchException, self).__init__(name)
self.status = status self.status = status
def _default_passive_closing_handshake_handler(request): def _default_passive_closing_handshake_handler(request):
"""Default web_socket_passive_closing_handshake handler.""" """Default web_socket_passive_closing_handshake handler."""
return common.STATUS_NORMAL, '' return common.STATUS_NORMAL, ''
@ -76,15 +77,21 @@ def _normalize_path(path):
""" """
path = path.replace('\\', os.path.sep) path = path.replace('\\', os.path.sep)
path = os.path.realpath(path)
# MOZILLA: do not normalize away symlinks in mochitest
#path = os.path.realpath(path)
path = path.replace('\\', '/') path = path.replace('\\', '/')
return path return path
def _create_path_to_resource_converter(base_dir): def _create_path_to_resource_converter(base_dir):
"""Returns a function that converts the path of a WebSocket handler source
file to a resource string by removing the path to the base directory from
its head, removing _SOURCE_SUFFIX from its tail, and replacing path
separators in it with '/'.
Args:
base_dir: the path to the base directory.
"""
base_dir = _normalize_path(base_dir) base_dir = _normalize_path(base_dir)
base_len = len(base_dir) base_len = len(base_dir)
@ -93,7 +100,9 @@ def _create_path_to_resource_converter(base_dir):
def converter(path): def converter(path):
if not path.endswith(_SOURCE_SUFFIX): if not path.endswith(_SOURCE_SUFFIX):
return None return None
path = _normalize_path(path) # _normalize_path must not be used because resolving symlink breaks
# following path check.
path = path.replace('\\', '/')
if not path.startswith(base_dir): if not path.startswith(base_dir):
return None return None
return path[base_len:-suffix_len] return path[base_len:-suffix_len]
@ -169,7 +178,9 @@ class Dispatcher(object):
This class maintains a map from resource name to handlers. This class maintains a map from resource name to handlers.
""" """
def __init__(self, root_dir, scan_dir=None): def __init__(
self, root_dir, scan_dir=None,
allow_handlers_outside_root_dir=True):
"""Construct an instance. """Construct an instance.
Args: Args:
@ -181,6 +192,8 @@ class Dispatcher(object):
root_dir is used as scan_dir. scan_dir can be useful root_dir is used as scan_dir. scan_dir can be useful
in saving scan time when root_dir contains many in saving scan time when root_dir contains many
subdirectories. subdirectories.
allow_handlers_outside_root_dir: Scans handler files even if their
canonical path is not under root_dir.
""" """
self._logger = util.get_class_logger(self) self._logger = util.get_class_logger(self)
@ -193,7 +206,8 @@ class Dispatcher(object):
os.path.realpath(root_dir)): os.path.realpath(root_dir)):
raise DispatchException('scan_dir:%s must be a directory under ' raise DispatchException('scan_dir:%s must be a directory under '
'root_dir:%s.' % (scan_dir, root_dir)) 'root_dir:%s.' % (scan_dir, root_dir))
self._source_handler_files_in_dir(root_dir, scan_dir) self._source_handler_files_in_dir(
root_dir, scan_dir, allow_handlers_outside_root_dir)
def add_resource_path_alias(self, def add_resource_path_alias(self,
alias_resource_path, existing_resource_path): alias_resource_path, existing_resource_path):
@ -247,7 +261,7 @@ class Dispatcher(object):
_DO_EXTRA_HANDSHAKE_HANDLER_NAME, _DO_EXTRA_HANDSHAKE_HANDLER_NAME,
request.ws_resource), request.ws_resource),
e) e)
raise handshake.HandshakeException(e, 403) raise handshake.HandshakeException(e, common.HTTP_STATUS_FORBIDDEN)
def transfer_data(self, request): def transfer_data(self, request):
"""Let a handler transfer_data with a WebSocket client. """Let a handler transfer_data with a WebSocket client.
@ -288,8 +302,9 @@ class Dispatcher(object):
self._logger.debug('%s', e) self._logger.debug('%s', e)
request.ws_stream.close_connection(common.STATUS_UNSUPPORTED) request.ws_stream.close_connection(common.STATUS_UNSUPPORTED)
except stream.InvalidUTF8Exception, e: except stream.InvalidUTF8Exception, e:
self._logger_debug('%s', e) self._logger.debug('%s', e)
request.ws_stream.close_connection(common.STATUS_INVALID_UTF8) request.ws_stream.close_connection(
common.STATUS_INVALID_FRAME_PAYLOAD)
except msgutil.ConnectionTerminatedException, e: except msgutil.ConnectionTerminatedException, e:
self._logger.debug('%s', e) self._logger.debug('%s', e)
except Exception, e: except Exception, e:
@ -322,23 +337,45 @@ class Dispatcher(object):
handler_suite = self._handler_suite_map.get(resource) handler_suite = self._handler_suite_map.get(resource)
if handler_suite and fragment: if handler_suite and fragment:
raise DispatchException('Fragment identifiers MUST NOT be used on ' raise DispatchException('Fragment identifiers MUST NOT be used on '
'WebSocket URIs', 400); 'WebSocket URIs',
common.HTTP_STATUS_BAD_REQUEST)
return handler_suite return handler_suite
def _source_handler_files_in_dir(self, root_dir, scan_dir): def _source_handler_files_in_dir(
self, root_dir, scan_dir, allow_handlers_outside_root_dir):
"""Source all the handler source files in the scan_dir directory. """Source all the handler source files in the scan_dir directory.
The resource path is determined relative to root_dir. The resource path is determined relative to root_dir.
""" """
# We build a map from resource to handler code assuming that there's
# only one path from root_dir to scan_dir and it can be obtained by
# comparing realpath of them.
# Here we cannot use abspath. See
# https://bugs.webkit.org/show_bug.cgi?id=31603
convert = _create_path_to_resource_converter(root_dir) convert = _create_path_to_resource_converter(root_dir)
for path in _enumerate_handler_file_paths(scan_dir): scan_realpath = os.path.realpath(scan_dir)
root_realpath = os.path.realpath(root_dir)
for path in _enumerate_handler_file_paths(scan_realpath):
if (not allow_handlers_outside_root_dir and
(not os.path.realpath(path).startswith(root_realpath))):
self._logger.debug(
'Canonical path of %s is not under root directory' %
path)
continue
try: try:
handler_suite = _source_handler_file(open(path).read()) handler_suite = _source_handler_file(open(path).read())
except DispatchException, e: except DispatchException, e:
self._source_warnings.append('%s: %s' % (path, e)) self._source_warnings.append('%s: %s' % (path, e))
continue continue
self._handler_suite_map[convert(path)] = handler_suite resource = convert(path)
if resource is None:
self._logger.debug(
'Path to resource conversion on %s failed' % path)
else:
self._handler_suite_map[convert(path)] = handler_suite
# vi:sts=4 sw=4 et # vi:sts=4 sw=4 et

View File

@ -36,6 +36,7 @@ _available_processors = {}
class ExtensionProcessorInterface(object): class ExtensionProcessorInterface(object):
def get_extension_response(self): def get_extension_response(self):
return None return None
@ -131,7 +132,9 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
return response return response
def setup_stream_options(self, stream_options): def setup_stream_options(self, stream_options):
class _OutgoingFilter(object): class _OutgoingFilter(object):
def __init__(self, parent): def __init__(self, parent):
self._parent = parent self._parent = parent
@ -139,6 +142,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
self._parent._outgoing_filter(frame) self._parent._outgoing_filter(frame)
class _IncomingFilter(object): class _IncomingFilter(object):
def __init__(self, parent): def __init__(self, parent):
self._parent = parent self._parent = parent

View File

@ -36,12 +36,15 @@ successfully established.
import logging import logging
from mod_pywebsocket import common
from mod_pywebsocket.handshake import draft75 from mod_pywebsocket.handshake import draft75
from mod_pywebsocket.handshake import hybi00 from mod_pywebsocket.handshake import hybi00
from mod_pywebsocket.handshake import hybi from mod_pywebsocket.handshake import hybi
# Export AbortedByUserException and HandshakeException symbol from this module. # Export AbortedByUserException, HandshakeException, and VersionException
# symbol from this module.
from mod_pywebsocket.handshake._base import AbortedByUserException from mod_pywebsocket.handshake._base import AbortedByUserException
from mod_pywebsocket.handshake._base import HandshakeException from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import VersionException
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -62,7 +65,7 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False):
handshake. handshake.
""" """
_LOGGER.debug('Opening handshake resource: %r', request.uri) _LOGGER.debug('Client\'s opening handshake resource: %r', request.uri)
# To print mimetools.Message as escaped one-line string, we converts # To print mimetools.Message as escaped one-line string, we converts
# headers_in to dict object. Without conversion, if we use %r, it just # headers_in to dict object. Without conversion, if we use %r, it just
# prints the type and address, and if we use %s, it prints the original # prints the type and address, and if we use %s, it prints the original
@ -76,7 +79,7 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False):
# header values. While MpTable_Type doesn't have such __str__ but just # header values. While MpTable_Type doesn't have such __str__ but just
# __repr__ which formats itself as well as dictionary object. # __repr__ which formats itself as well as dictionary object.
_LOGGER.debug( _LOGGER.debug(
'Opening handshake request headers: %r', dict(request.headers_in)) 'Client\'s opening handshake headers: %r', dict(request.headers_in))
handshakers = [] handshakers = []
handshakers.append( handshakers.append(
@ -88,21 +91,26 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False):
('IETF Hixie 75', draft75.Handshaker(request, dispatcher, strict))) ('IETF Hixie 75', draft75.Handshaker(request, dispatcher, strict)))
for name, handshaker in handshakers: for name, handshaker in handshakers:
_LOGGER.info('Trying %s protocol', name) _LOGGER.debug('Trying %s protocol', name)
try: try:
handshaker.do_handshake() handshaker.do_handshake()
_LOGGER.info('Established (%s protocol)', name)
return return
except HandshakeException, e: except HandshakeException, e:
_LOGGER.info( _LOGGER.debug(
'Failed to complete opening handshake as %s protocol: %r', 'Failed to complete opening handshake as %s protocol: %r',
name, e) name, e)
if e.status: if e.status:
raise e raise e
except AbortedByUserException, e: except AbortedByUserException, e:
raise raise
except VersionException, e:
raise
# TODO(toyoshim): Add a test to cover the case all handshakers fail.
raise HandshakeException( raise HandshakeException(
'Failed to complete opening handshake for all available protocols') 'Failed to complete opening handshake for all available protocols',
status=common.HTTP_STATUS_BAD_REQUEST)
# vi:sts=4 sw=4 et # vi:sts=4 sw=4 et

View File

@ -61,6 +61,22 @@ class HandshakeException(Exception):
self.status = status self.status = status
class VersionException(Exception):
"""This exception will be raised when a version of client request does not
match with version the server supports.
"""
def __init__(self, name, supported_versions=''):
"""Construct an instance.
Args:
supported_version: a str object to show supported hybi versions.
(e.g. '8, 13')
"""
super(VersionException, self).__init__(name)
self.supported_versions = supported_versions
def get_default_port(is_secure): def get_default_port(is_secure):
if is_secure: if is_secure:
return common.DEFAULT_WEB_SOCKET_SECURE_PORT return common.DEFAULT_WEB_SOCKET_SECURE_PORT
@ -200,7 +216,7 @@ def parse_token_list(data):
return token_list return token_list
def _parse_extension_param(state, definition): def _parse_extension_param(state, definition, allow_quoted_string):
param_name = http_header_util.consume_token(state) param_name = http_header_util.consume_token(state)
if param_name is None: if param_name is None:
@ -214,7 +230,11 @@ def _parse_extension_param(state, definition):
http_header_util.consume_lwses(state) http_header_util.consume_lwses(state)
param_value = http_header_util.consume_token_or_quoted_string(state) if allow_quoted_string:
# TODO(toyoshim): Add code to validate that parsed param_value is token
param_value = http_header_util.consume_token_or_quoted_string(state)
else:
param_value = http_header_util.consume_token(state)
if param_value is None: if param_value is None:
raise HandshakeException( raise HandshakeException(
'No valid parameter value found on the right-hand side of ' 'No valid parameter value found on the right-hand side of '
@ -223,7 +243,7 @@ def _parse_extension_param(state, definition):
definition.add_parameter(param_name, param_value) definition.add_parameter(param_name, param_value)
def _parse_extension(state): def _parse_extension(state, allow_quoted_string):
extension_token = http_header_util.consume_token(state) extension_token = http_header_util.consume_token(state)
if extension_token is None: if extension_token is None:
return None return None
@ -239,7 +259,7 @@ def _parse_extension(state):
http_header_util.consume_lwses(state) http_header_util.consume_lwses(state)
try: try:
_parse_extension_param(state, extension) _parse_extension_param(state, extension, allow_quoted_string)
except HandshakeException, e: except HandshakeException, e:
raise HandshakeException( raise HandshakeException(
'Failed to parse Sec-WebSocket-Extensions header: ' 'Failed to parse Sec-WebSocket-Extensions header: '
@ -249,7 +269,7 @@ def _parse_extension(state):
return extension return extension
def parse_extensions(data): def parse_extensions(data, allow_quoted_string=False):
"""Parses Sec-WebSocket-Extensions header value returns a list of """Parses Sec-WebSocket-Extensions header value returns a list of
common.ExtensionParameter objects. common.ExtensionParameter objects.
@ -260,7 +280,7 @@ def parse_extensions(data):
extension_list = [] extension_list = []
while True: while True:
extension = _parse_extension(state) extension = _parse_extension(state, allow_quoted_string)
if extension is not None: if extension is not None:
extension_list.append(extension) extension_list.append(extension)

View File

@ -53,6 +53,7 @@ from mod_pywebsocket.handshake._base import parse_extensions
from mod_pywebsocket.handshake._base import parse_token_list from mod_pywebsocket.handshake._base import parse_token_list
from mod_pywebsocket.handshake._base import validate_mandatory_header from mod_pywebsocket.handshake._base import validate_mandatory_header
from mod_pywebsocket.handshake._base import validate_subprotocol from mod_pywebsocket.handshake._base import validate_subprotocol
from mod_pywebsocket.handshake._base import VersionException
from mod_pywebsocket.stream import Stream from mod_pywebsocket.stream import Stream
from mod_pywebsocket.stream import StreamOptions from mod_pywebsocket.stream import StreamOptions
from mod_pywebsocket import util from mod_pywebsocket import util
@ -60,6 +61,16 @@ from mod_pywebsocket import util
_BASE64_REGEX = re.compile('^[+/0-9A-Za-z]*=*$') _BASE64_REGEX = re.compile('^[+/0-9A-Za-z]*=*$')
# Defining aliases for values used frequently.
_VERSION_HYBI08 = common.VERSION_HYBI08
_VERSION_HYBI08_STRING = str(_VERSION_HYBI08)
_VERSION_LATEST = common.VERSION_HYBI_LATEST
_VERSION_LATEST_STRING = str(_VERSION_LATEST)
_SUPPORTED_VERSIONS = [
_VERSION_LATEST,
_VERSION_HYBI08,
]
def compute_accept(key): def compute_accept(key):
"""Computes value for the Sec-WebSocket-Accept header from value of the """Computes value for the Sec-WebSocket-Accept header from value of the
@ -130,7 +141,7 @@ class Handshaker(object):
unused_host = get_mandatory_header(self._request, common.HOST_HEADER) unused_host = get_mandatory_header(self._request, common.HOST_HEADER)
self._check_version() self._request.ws_version = self._check_version()
# This handshake must be based on latest hybi. We are responsible to # This handshake must be based on latest hybi. We are responsible to
# fallback to HTTP on handshake failure as latest hybi handshake # fallback to HTTP on handshake failure as latest hybi handshake
@ -151,7 +162,6 @@ class Handshaker(object):
util.hexify(accept_binary)) util.hexify(accept_binary))
self._logger.debug('IETF HyBi protocol') self._logger.debug('IETF HyBi protocol')
self._request.ws_version = common.VERSION_HYBI_LATEST
# Setup extension processors. # Setup extension processors.
@ -212,29 +222,42 @@ class Handshaker(object):
'request any subprotocol') 'request any subprotocol')
self._send_handshake(accept) self._send_handshake(accept)
self._logger.debug('Sent opening handshake response')
except HandshakeException, e: except HandshakeException, e:
if not e.status: if not e.status:
# Fallback to 400 bad request by default. # Fallback to 400 bad request by default.
e.status = 400 e.status = common.HTTP_STATUS_BAD_REQUEST
raise e raise e
def _get_origin(self): def _get_origin(self):
origin = self._request.headers_in.get( if self._request.ws_version is _VERSION_HYBI08:
common.SEC_WEBSOCKET_ORIGIN_HEADER) origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER
else:
origin_header = common.ORIGIN_HEADER
origin = self._request.headers_in.get(origin_header)
if origin is None:
self._logger.debug('Client request does not have origin header')
self._request.ws_origin = origin self._request.ws_origin = origin
def _check_version(self): def _check_version(self):
unused_value = validate_mandatory_header( version = get_mandatory_header(self._request,
self._request, common.SEC_WEBSOCKET_VERSION_HEADER, common.SEC_WEBSOCKET_VERSION_HEADER)
str(common.VERSION_HYBI_LATEST), fail_status=426) if version == _VERSION_HYBI08_STRING:
return _VERSION_HYBI08
if version == _VERSION_LATEST_STRING:
return _VERSION_LATEST
if version.find(',') >= 0:
raise HandshakeException(
'Multiple versions (%r) are not allowed for header %s' %
(version, common.SEC_WEBSOCKET_VERSION_HEADER),
status=common.HTTP_STATUS_BAD_REQUEST)
raise VersionException(
'Unsupported version %r for header %s' %
(version, common.SEC_WEBSOCKET_VERSION_HEADER),
supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS)))
def _set_protocol(self): def _set_protocol(self):
self._request.ws_protocol = None self._request.ws_protocol = None
# MOZILLA
self._request.sts = None
# /MOZILLA
protocol_header = self._request.headers_in.get( protocol_header = self._request.headers_in.get(
common.SEC_WEBSOCKET_PROTOCOL_HEADER) common.SEC_WEBSOCKET_PROTOCOL_HEADER)
@ -255,8 +278,12 @@ class Handshaker(object):
self._request.ws_requested_extensions = None self._request.ws_requested_extensions = None
return return
if self._request.ws_version is common.VERSION_HYBI08:
allow_quoted_string=False
else:
allow_quoted_string=True
self._request.ws_requested_extensions = parse_extensions( self._request.ws_requested_extensions = parse_extensions(
extensions_header) extensions_header, allow_quoted_string=allow_quoted_string)
self._logger.debug( self._logger.debug(
'Extensions requested: %r', 'Extensions requested: %r',
@ -264,6 +291,11 @@ class Handshaker(object):
self._request.ws_requested_extensions)) self._request.ws_requested_extensions))
def _validate_key(self, key): def _validate_key(self, key):
if key.find(',') >= 0:
raise HandshakeException('Request has multiple %s header lines or '
'contains illegal character \',\': %r' %
(common.SEC_WEBSOCKET_KEY_HEADER, key))
# Validate # Validate
key_is_valid = False key_is_valid = False
try: try:
@ -319,16 +351,12 @@ class Handshaker(object):
response.append(format_header( response.append(format_header(
common.SEC_WEBSOCKET_EXTENSIONS_HEADER, common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
format_extensions(self._request.ws_extensions))) format_extensions(self._request.ws_extensions)))
# MOZILLA: Add HSTS header if requested to
if self._request.sts is not None:
response.append(format_header("Strict-Transport-Security",
self._request.sts))
# /MOZILLA
response.append('\r\n') response.append('\r\n')
raw_response = ''.join(response) raw_response = ''.join(response)
self._logger.debug('Opening handshake response: %r', raw_response)
self._request.connection.write(raw_response) self._request.connection.write(raw_response)
self._logger.debug('Sent server\'s opening handshake: %r',
raw_response)
# vi:sts=4 sw=4 et # vi:sts=4 sw=4 et

View File

@ -107,8 +107,6 @@ class Handshaker(object):
self._send_handshake() self._send_handshake()
self._logger.debug('Sent opening handshake response')
def _set_resource(self): def _set_resource(self):
self._request.ws_resource = self._request.uri self._request.ws_resource = self._request.uri
@ -138,7 +136,8 @@ class Handshaker(object):
draft = self._request.headers_in.get(common.SEC_WEBSOCKET_DRAFT_HEADER) draft = self._request.headers_in.get(common.SEC_WEBSOCKET_DRAFT_HEADER)
if draft is not None and draft != '0': if draft is not None and draft != '0':
raise HandshakeException('Illegal value for %s: %s' % raise HandshakeException('Illegal value for %s: %s' %
(common.SEC_WEBSOCKET_DRAFT_HEADER, draft)) (common.SEC_WEBSOCKET_DRAFT_HEADER,
draft))
self._logger.debug('IETF HyBi 00 protocol') self._logger.debug('IETF HyBi 00 protocol')
self._request.ws_version = common.VERSION_HYBI00 self._request.ws_version = common.VERSION_HYBI00
@ -229,8 +228,9 @@ class Handshaker(object):
response.append(self._request.ws_challenge_md5) response.append(self._request.ws_challenge_md5)
raw_response = ''.join(response) raw_response = ''.join(response)
self._logger.debug('Opening handshake response: %r', raw_response)
self._request.connection.write(raw_response) self._request.connection.write(raw_response)
self._logger.debug('Sent server\'s opening handshake: %r',
raw_response)
# vi:sts=4 sw=4 et # vi:sts=4 sw=4 et

View File

@ -39,6 +39,7 @@ import logging
from mod_python import apache from mod_python import apache
from mod_pywebsocket import common
from mod_pywebsocket import dispatch from mod_pywebsocket import dispatch
from mod_pywebsocket import handshake from mod_pywebsocket import handshake
from mod_pywebsocket import util from mod_pywebsocket import util
@ -52,9 +53,21 @@ _PYOPT_HANDLER_ROOT = 'mod_pywebsocket.handler_root'
# The default is the root directory. # The default is the root directory.
_PYOPT_HANDLER_SCAN = 'mod_pywebsocket.handler_scan' _PYOPT_HANDLER_SCAN = 'mod_pywebsocket.handler_scan'
# PythonOption to allow handlers whose canonical path is
# not under the root directory. It's disallowed by default.
# Set this option with value of 'yes' to allow.
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT = (
'mod_pywebsocket.allow_handlers_outside_root_dir')
# Map from values to their meanings. 'Yes' and 'No' are allowed just for
# compatibility.
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT_DEFINITION = {
'off': False, 'no': False, 'on': True, 'yes': True}
# PythonOption to specify to allow draft75 handshake. # PythonOption to specify to allow draft75 handshake.
# The default is None (Off) # The default is None (Off)
_PYOPT_ALLOW_DRAFT75 = 'mod_pywebsocket.allow_draft75' _PYOPT_ALLOW_DRAFT75 = 'mod_pywebsocket.allow_draft75'
# Map from values to their meanings.
_PYOPT_ALLOW_DRAFT75_DEFINITION = {'off': False, 'on': True}
class ApacheLogHandler(logging.Handler): class ApacheLogHandler(logging.Handler):
@ -70,15 +83,20 @@ class ApacheLogHandler(logging.Handler):
def __init__(self, request=None): def __init__(self, request=None):
logging.Handler.__init__(self) logging.Handler.__init__(self)
self.log_error = apache.log_error self._log_error = apache.log_error
if request is not None: if request is not None:
self.log_error = request.log_error self._log_error = request.log_error
# Time and level will be printed by Apache.
self._formatter = logging.Formatter('%(name)s: %(message)s')
def emit(self, record): def emit(self, record):
apache_level = apache.APLOG_DEBUG apache_level = apache.APLOG_DEBUG
if record.levelno in ApacheLogHandler._LEVELS: if record.levelno in ApacheLogHandler._LEVELS:
apache_level = ApacheLogHandler._LEVELS[record.levelno] apache_level = ApacheLogHandler._LEVELS[record.levelno]
msg = self._formatter.format(record)
# "server" parameter must be passed to have "level" parameter work. # "server" parameter must be passed to have "level" parameter work.
# If only "level" parameter is passed, nothing shows up on Apache's # If only "level" parameter is passed, nothing shows up on Apache's
# log. However, at this point, we cannot get the server object of the # log. However, at this point, we cannot get the server object of the
@ -99,28 +117,57 @@ class ApacheLogHandler(logging.Handler):
# methods call request.log_error indirectly. When request is # methods call request.log_error indirectly. When request is
# _StandaloneRequest, the methods call Python's logging facility which # _StandaloneRequest, the methods call Python's logging facility which
# we create in standalone.py. # we create in standalone.py.
self.log_error(record.getMessage(), apache_level, apache.main_server) self._log_error(msg, apache_level, apache.main_server)
_LOGGER = logging.getLogger('mod_pywebsocket') def _configure_logging():
# Logs are filtered by Apache based on LogLevel directive in Apache logger = logging.getLogger()
# configuration file. We must just pass logs for all levels to # Logs are filtered by Apache based on LogLevel directive in Apache
# ApacheLogHandler. # configuration file. We must just pass logs for all levels to
_LOGGER.setLevel(logging.DEBUG) # ApacheLogHandler.
_LOGGER.addHandler(ApacheLogHandler()) logger.setLevel(logging.DEBUG)
logger.addHandler(ApacheLogHandler())
_configure_logging()
_LOGGER = logging.getLogger(__name__)
def _parse_option(name, value, definition):
if value is None:
return False
meaning = definition.get(value.lower())
if meaning is None:
raise Exception('Invalid value for PythonOption %s: %r' %
(name, value))
return meaning
def _create_dispatcher(): def _create_dispatcher():
_HANDLER_ROOT = apache.main_server.get_options().get( _LOGGER.info('Initializing Dispatcher')
_PYOPT_HANDLER_ROOT, None)
if not _HANDLER_ROOT: options = apache.main_server.get_options()
handler_root = options.get(_PYOPT_HANDLER_ROOT, None)
if not handler_root:
raise Exception('PythonOption %s is not defined' % _PYOPT_HANDLER_ROOT, raise Exception('PythonOption %s is not defined' % _PYOPT_HANDLER_ROOT,
apache.APLOG_ERR) apache.APLOG_ERR)
_HANDLER_SCAN = apache.main_server.get_options().get(
_PYOPT_HANDLER_SCAN, _HANDLER_ROOT) handler_scan = options.get(_PYOPT_HANDLER_SCAN, handler_root)
dispatcher = dispatch.Dispatcher(_HANDLER_ROOT, _HANDLER_SCAN)
allow_handlers_outside_root = _parse_option(
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT,
options.get(_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT),
_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT_DEFINITION)
dispatcher = dispatch.Dispatcher(
handler_root, handler_scan, allow_handlers_outside_root)
for warning in dispatcher.source_warnings(): for warning in dispatcher.source_warnings():
apache.log_error('mod_pywebsocket: %s' % warning, apache.APLOG_WARNING) apache.log_error('mod_pywebsocket: %s' % warning, apache.APLOG_WARNING)
return dispatcher return dispatcher
@ -140,33 +187,54 @@ def headerparserhandler(request):
handshake_is_done = False handshake_is_done = False
try: try:
allowDraft75 = apache.main_server.get_options().get( # Fallback to default http handler for request paths for which
_PYOPT_ALLOW_DRAFT75, None) # we don't have request handlers.
handshake.do_handshake( if not _dispatcher.get_handler_suite(request.uri):
request, _dispatcher, allowDraft75=allowDraft75) request.log_error('No handler for resource: %r' % request.uri,
handshake_is_done = True apache.APLOG_INFO)
request.log_error( request.log_error('Fallback to Apache', apache.APLOG_INFO)
'mod_pywebsocket: resource: %r' % request.ws_resource, return apache.DECLINED
apache.APLOG_DEBUG)
request._dispatcher = _dispatcher
_dispatcher.transfer_data(request)
except dispatch.DispatchException, e: except dispatch.DispatchException, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_WARNING) request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
if not handshake_is_done: if not handshake_is_done:
return e.status return e.status
try:
allow_draft75 = _parse_option(
_PYOPT_ALLOW_DRAFT75,
apache.main_server.get_options().get(_PYOPT_ALLOW_DRAFT75),
_PYOPT_ALLOW_DRAFT75_DEFINITION)
try:
handshake.do_handshake(
request, _dispatcher, allowDraft75=allow_draft75)
except handshake.VersionException, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
request.err_headers_out.add(common.SEC_WEBSOCKET_VERSION_HEADER,
e.supported_versions)
return apache.HTTP_BAD_REQUEST
except handshake.HandshakeException, e:
# Handshake for ws/wss failed.
# Send http response with error status.
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
return e.status
handshake_is_done = True
request._dispatcher = _dispatcher
_dispatcher.transfer_data(request)
except handshake.AbortedByUserException, e: except handshake.AbortedByUserException, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
except handshake.HandshakeException, e:
# Handshake for ws/wss failed.
# The request handling fallback into http/https.
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
return e.status
except Exception, e: except Exception, e:
request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_WARNING) # DispatchException can also be thrown if something is wrong in
# pywebsocket code. It's caught here, then.
request.log_error('mod_pywebsocket: %s\n%s' %
(e, util.get_stack_trace()),
apache.APLOG_ERR)
# Unknown exceptions before handshake mean Apache must handle its # Unknown exceptions before handshake mean Apache must handle its
# request with another handler. # request with another handler.
if not handshake_is_done: if not handshake_is_done:
return apache.DECLINE return apache.DECLINED
# Set assbackwards to suppress response header generation by Apache. # Set assbackwards to suppress response header generation by Apache.
request.assbackwards = 1 request.assbackwards = 1
return apache.DONE # Return DONE such that no other handlers are invoked. return apache.DONE # Return DONE such that no other handlers are invoked.

View File

@ -52,6 +52,7 @@ def _is_ctl(c):
class ParsingState(object): class ParsingState(object):
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
self.head = 0 self.head = 0
@ -209,7 +210,7 @@ def quote_if_necessary(s):
result.append(c) result.append(c)
if quote: if quote:
return '"' + ''.join(result) + '"'; return '"' + ''.join(result) + '"'
else: else:
return ''.join(result) return ''.join(result)
@ -251,4 +252,12 @@ def parse_uri(uri):
return parsed.hostname, port, path return parsed.hostname, port, path
try:
urlparse.uses_netloc.index('ws')
except ValueError, e:
# urlparse in Python2.5.1 doesn't have 'ws' and 'wss' entries.
urlparse.uses_netloc.append('ws')
urlparse.uses_netloc.append('wss')
# vi:sts=4 sw=4 et # vi:sts=4 sw=4 et

View File

@ -218,6 +218,7 @@ class DeflateRequest(object):
class _Deflater(object): class _Deflater(object):
def __init__(self, window_bits): def __init__(self, window_bits):
self._logger = get_class_logger(self) self._logger = get_class_logger(self)
@ -233,6 +234,7 @@ class _Deflater(object):
class _Inflater(object): class _Inflater(object):
def __init__(self): def __init__(self):
self._logger = get_class_logger(self) self._logger = get_class_logger(self)
@ -390,6 +392,10 @@ class DeflateConnection(object):
self._deflater = _Deflater(zlib.MAX_WBITS) self._deflater = _Deflater(zlib.MAX_WBITS)
self._inflater = _Inflater() self._inflater = _Inflater()
def get_remote_addr(self):
return self._connection.remote_addr
remote_addr = property(get_remote_addr)
def put_bytes(self, bytes): def put_bytes(self, bytes):
self.write(bytes) self.write(bytes)

View File

@ -65,6 +65,7 @@ import BaseHTTPServer
import CGIHTTPServer import CGIHTTPServer
import SimpleHTTPServer import SimpleHTTPServer
import SocketServer import SocketServer
import httplib
import logging import logging
import logging.handlers import logging.handlers
import optparse import optparse
@ -74,6 +75,8 @@ import select
import socket import socket
import sys import sys
import threading import threading
import time
_HAS_OPEN_SSL = False _HAS_OPEN_SSL = False
try: try:
@ -99,13 +102,6 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128
_MAX_MEMORIZED_LINES = 1024 _MAX_MEMORIZED_LINES = 1024
def _print_warnings_if_any(dispatcher):
warnings = dispatcher.source_warnings()
if warnings:
for warning in warnings:
logging.warning('mod_pywebsocket: %s' % warning)
class _StandaloneConnection(object): class _StandaloneConnection(object):
"""Mimic mod_python mp_conn.""" """Mimic mod_python mp_conn."""
@ -165,6 +161,7 @@ class _StandaloneRequest(object):
self._request_handler = request_handler self._request_handler = request_handler
self.connection = _StandaloneConnection(request_handler) self.connection = _StandaloneConnection(request_handler)
self._use_tls = use_tls self._use_tls = use_tls
self.headers_in = request_handler.headers
def get_uri(self): def get_uri(self):
"""Getter to mimic request.uri.""" """Getter to mimic request.uri."""
@ -178,12 +175,6 @@ class _StandaloneRequest(object):
return self._request_handler.command return self._request_handler.command
method = property(get_method) method = property(get_method)
def get_headers_in(self):
"""Getter to mimic request.headers_in."""
return self._request_handler.headers
headers_in = property(get_headers_in)
def is_https(self): def is_https(self):
"""Mimic request.is_https().""" """Mimic request.is_https()."""
@ -216,6 +207,8 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
if necessary. if necessary.
""" """
self._logger = util.get_class_logger(self)
self.request_queue_size = options.request_queue_size self.request_queue_size = options.request_queue_size
self.__ws_is_shut_down = threading.Event() self.__ws_is_shut_down = threading.Event()
self.__ws_serving = False self.__ws_serving = False
@ -235,8 +228,16 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
self.server_name, self.server_port = self.server_address self.server_name, self.server_port = self.server_address
self._sockets = [] self._sockets = []
if not self.server_name: if not self.server_name:
# On platforms that doesn't support IPv6, the first bind fails.
# On platforms that supports IPv6
# - If it binds both IPv4 and IPv6 on call with AF_INET6, the
# first bind succeeds and the second fails (we'll see 'Address
# already in use' error).
# - If it binds only IPv6 on call with AF_INET6, both call are
# expected to succeed to listen both protocol.
addrinfo_array = [ addrinfo_array = [
(self.address_family, self.socket_type, '', '', '')] (socket.AF_INET6, socket.SOCK_STREAM, '', '', ''),
(socket.AF_INET, socket.SOCK_STREAM, '', '', '')]
else: else:
addrinfo_array = socket.getaddrinfo(self.server_name, addrinfo_array = socket.getaddrinfo(self.server_name,
self.server_port, self.server_port,
@ -244,12 +245,12 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
socket.SOCK_STREAM, socket.SOCK_STREAM,
socket.IPPROTO_TCP) socket.IPPROTO_TCP)
for addrinfo in addrinfo_array: for addrinfo in addrinfo_array:
logging.info('Create socket on: %r', addrinfo) self._logger.info('Create socket on: %r', addrinfo)
family, socktype, proto, canonname, sockaddr = addrinfo family, socktype, proto, canonname, sockaddr = addrinfo
try: try:
socket_ = socket.socket(family, socktype) socket_ = socket.socket(family, socktype)
except Exception, e: except Exception, e:
logging.info('Skip by failure: %r', e) self._logger.info('Skip by failure: %r', e)
continue continue
if self.websocket_server_options.use_tls: if self.websocket_server_options.use_tls:
ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
@ -265,11 +266,22 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
sockets bind. sockets bind.
""" """
for socket_, addrinfo in self._sockets: failed_sockets = []
logging.info('Bind on: %r', addrinfo)
for socketinfo in self._sockets:
socket_, addrinfo = socketinfo
self._logger.info('Bind on: %r', addrinfo)
if self.allow_reuse_address: if self.allow_reuse_address:
socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
socket_.bind(self.server_address) try:
socket_.bind(self.server_address)
except Exception, e:
self._logger.info('Skip by failure: %r', e)
socket_.close()
failed_sockets.append(socketinfo)
for socketinfo in failed_sockets:
self._sockets.remove(socketinfo)
def server_activate(self): def server_activate(self):
"""Override SocketServer.TCPServer.server_activate to enable multiple """Override SocketServer.TCPServer.server_activate to enable multiple
@ -280,11 +292,11 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
for socketinfo in self._sockets: for socketinfo in self._sockets:
socket_, addrinfo = socketinfo socket_, addrinfo = socketinfo
logging.info('Listen on: %r', addrinfo) self._logger.info('Listen on: %r', addrinfo)
try: try:
socket_.listen(self.request_queue_size) socket_.listen(self.request_queue_size)
except Exception, e: except Exception, e:
logging.info('Skip by failure: %r', e) self._logger.info('Skip by failure: %r', e)
socket_.close() socket_.close()
failed_sockets.append(socketinfo) failed_sockets.append(socketinfo)
@ -298,23 +310,23 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
for socketinfo in self._sockets: for socketinfo in self._sockets:
socket_, addrinfo = socketinfo socket_, addrinfo = socketinfo
logging.info('Close on: %r', addrinfo) self._logger.info('Close on: %r', addrinfo)
socket_.close() socket_.close()
def fileno(self): def fileno(self):
"""Override SocketServer.TCPServer.fileno.""" """Override SocketServer.TCPServer.fileno."""
logging.critical('Not supported: fileno') self._logger.critical('Not supported: fileno')
return self._sockets[0][0].fileno() return self._sockets[0][0].fileno()
def handle_error(self, rquest, client_address): def handle_error(self, rquest, client_address):
"""Override SocketServer.handle_error.""" """Override SocketServer.handle_error."""
logging.error( self._logger.error(
('Exception in processing request from: %r' % (client_address,)) + 'Exception in processing request from: %r\n%s',
'\n' + util.get_stack_trace()) client_address,
# Note: client_address is a tuple. To match it against %r, we need the util.get_stack_trace())
# trailing comma. # Note: client_address is a tuple.
def serve_forever(self, poll_interval=0.5): def serve_forever(self, poll_interval=0.5):
"""Override SocketServer.BaseServer.serve_forever.""" """Override SocketServer.BaseServer.serve_forever."""
@ -325,8 +337,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
if hasattr(self, '_handle_request_noblock'): if hasattr(self, '_handle_request_noblock'):
handle_request = self._handle_request_noblock handle_request = self._handle_request_noblock
else: else:
logging.warning('mod_pywebsocket: fallback to blocking request ' self._logger.warning('Fallback to blocking request handler')
'handler')
try: try:
while self.__ws_serving: while self.__ws_serving:
r, w, e = select.select( r, w, e = select.select(
@ -349,6 +360,9 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
"""CGIHTTPRequestHandler specialized for WebSocket.""" """CGIHTTPRequestHandler specialized for WebSocket."""
# Use httplib.HTTPMessage instead of mimetools.Message.
MessageClass = httplib.HTTPMessage
def setup(self): def setup(self):
"""Override SocketServer.StreamRequestHandler.setup to wrap rfile """Override SocketServer.StreamRequestHandler.setup to wrap rfile
with MemorizingFile. with MemorizingFile.
@ -370,6 +384,8 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
max_memorized_lines=_MAX_MEMORIZED_LINES) max_memorized_lines=_MAX_MEMORIZED_LINES)
def __init__(self, request, client_address, server): def __init__(self, request, client_address, server):
self._logger = util.get_class_logger(self)
self._options = server.websocket_server_options self._options = server.websocket_server_options
# Overrides CGIHTTPServerRequestHandler.cgi_directories. # Overrides CGIHTTPServerRequestHandler.cgi_directories.
@ -378,10 +394,6 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
if self._options.is_executable_method is not None: if self._options.is_executable_method is not None:
self.is_executable = self._options.is_executable_method self.is_executable = self._options.is_executable_method
self._request = _StandaloneRequest(self, self._options.use_tls)
_print_warnings_if_any(self._options.dispatcher)
# This actually calls BaseRequestHandler.__init__. # This actually calls BaseRequestHandler.__init__.
CGIHTTPServer.CGIHTTPRequestHandler.__init__( CGIHTTPServer.CGIHTTPRequestHandler.__init__(
self, request, client_address, server) self, request, client_address, server)
@ -406,79 +418,87 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
return False return False
host, port, resource = http_header_util.parse_uri(self.path) host, port, resource = http_header_util.parse_uri(self.path)
if resource is None: if resource is None:
logging.info('mod_pywebsocket: invalid uri %r' % self.path) self._logger.info('Invalid URI: %r', self.path)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True return True
server_options = self.server.websocket_server_options server_options = self.server.websocket_server_options
if host is not None: if host is not None:
validation_host = server_options.validation_host validation_host = server_options.validation_host
if validation_host is not None and host != validation_host: if validation_host is not None and host != validation_host:
logging.info('mod_pywebsocket: invalid host %r ' self._logger.info('Invalid host: %r (expected: %r)',
'(expected: %r)' % (host, validation_host)) host,
validation_host)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True return True
if port is not None: if port is not None:
validation_port = server_options.validation_port validation_port = server_options.validation_port
if validation_port is not None and port != validation_port: if validation_port is not None and port != validation_port:
logging.info('mod_pywebsocket: invalid port %r ' self._logger.info('Invalid port: %r (expected: %r)',
'(expected: %r)' % (port, validation_port)) port,
validation_port)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True return True
self.path = resource self.path = resource
request = _StandaloneRequest(self, self._options.use_tls)
try: try:
# Fallback to default http handler for request paths for which # Fallback to default http handler for request paths for which
# we don't have request handlers. # we don't have request handlers.
if not self._options.dispatcher.get_handler_suite(self.path): if not self._options.dispatcher.get_handler_suite(self.path):
logging.info('No handlers for request: %s' % self.path) self._logger.info('No handler for resource: %r',
self.path)
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True return True
except dispatch.DispatchException, e:
self._logger.info('%s', e)
self.send_error(e.status)
return False
# If any Exceptions without except clause setup (including
# DispatchException) is raised below this point, it will be caught
# and logged by WebSocketServer.
try:
try: try:
handshake.do_handshake( handshake.do_handshake(
self._request, request,
self._options.dispatcher, self._options.dispatcher,
allowDraft75=self._options.allow_draft75, allowDraft75=self._options.allow_draft75,
strict=self._options.strict) strict=self._options.strict)
except handshake.AbortedByUserException, e: except handshake.VersionException, e:
logging.info('mod_pywebsocket: %s' % e) self._logger.info('%s', e)
self.send_response(common.HTTP_STATUS_BAD_REQUEST)
self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER,
e.supported_versions)
self.end_headers()
return False return False
try: except handshake.HandshakeException, e:
self._request._dispatcher = self._options.dispatcher # Handshake for ws(s) failed.
self._options.dispatcher.transfer_data(self._request) self._logger.info('%s', e)
except dispatch.DispatchException, e: self.send_error(e.status)
logging.warning('mod_pywebsocket: %s' % e)
return False return False
except handshake.AbortedByUserException, e:
logging.info('mod_pywebsocket: %s' % e) request._dispatcher = self._options.dispatcher
except Exception, e: self._options.dispatcher.transfer_data(request)
# Catch exception in transfer_data. except handshake.AbortedByUserException, e:
# In this case, handshake has been successful, so just log self._logger.info('%s', e)
# the exception and return False.
logging.info('mod_pywebsocket: %s' % e)
logging.info(
'mod_pywebsocket: %s' % util.get_stack_trace())
except dispatch.DispatchException, e:
logging.warning('mod_pywebsocket: %s' % e)
self.send_error(e.status)
except handshake.HandshakeException, e:
# Handshake for ws(s) failed. Assume http(s).
logging.info('mod_pywebsocket: %s' % e)
self.send_error(e.status)
except Exception, e:
logging.warning('mod_pywebsocket: %s' % e)
logging.warning('mod_pywebsocket: %s' % util.get_stack_trace())
return False return False
def log_request(self, code='-', size='-'): def log_request(self, code='-', size='-'):
"""Override BaseHTTPServer.log_request.""" """Override BaseHTTPServer.log_request."""
logging.info('"%s" %s %s', self._logger.info('"%s" %s %s',
self.requestline, str(code), str(size)) self.requestline, str(code), str(size))
def log_error(self, *args): def log_error(self, *args):
"""Override BaseHTTPServer.log_error.""" """Override BaseHTTPServer.log_error."""
# Despite the name, this method is for warnings than for errors. # Despite the name, this method is for warnings than for errors.
# For example, HTTP status code is logged by this method. # For example, HTTP status code is logged by this method.
logging.warning('%s - %s' % self._logger.warning('%s - %s',
(self.address_string(), (args[0] % args[1:]))) self.address_string(),
args[0] % args[1:])
def is_cgi(self): def is_cgi(self):
"""Test whether self.path corresponds to a CGI script. """Test whether self.path corresponds to a CGI script.
@ -544,8 +564,9 @@ def _alias_handlers(dispatcher, websock_handlers_map_file):
fp.close() fp.close()
def _main(): def _build_option_parser():
parser = optparse.OptionParser() parser = optparse.OptionParser()
parser.add_option('-H', '--server-host', '--server_host', parser.add_option('-H', '--server-host', '--server_host',
dest='server_host', dest='server_host',
default='', default='',
@ -576,6 +597,13 @@ def _main():
default=None, default=None,
help=('WebSocket handlers scan directory. ' help=('WebSocket handlers scan directory. '
'Must be a directory under websock_handlers.')) 'Must be a directory under websock_handlers.'))
parser.add_option('--allow-handlers-outside-root-dir',
'--allow_handlers_outside_root_dir',
dest='allow_handlers_outside_root_dir',
action='store_true',
default=False,
help=('Scans WebSocket handlers even if their canonical '
'path is not under websock_handlers.'))
parser.add_option('-d', '--document-root', '--document_root', parser.add_option('-d', '--document-root', '--document_root',
dest='document_root', default='.', dest='document_root', default='.',
help='Document root directory.') help='Document root directory.')
@ -599,6 +627,15 @@ def _main():
choices=['debug', 'info', 'warning', 'warn', 'error', choices=['debug', 'info', 'warning', 'warn', 'error',
'critical'], 'critical'],
help='Log level.') help='Log level.')
parser.add_option('--thread-monitor-interval-in-sec',
'--thread_monitor_interval_in_sec',
dest='thread_monitor_interval_in_sec',
type='int', default=-1,
help=('If positive integer is specified, run a thread '
'monitor to show the status of server threads '
'periodically in the specified inteval in '
'second. If non-positive integer is specified, '
'disable the thread monitor.'))
parser.add_option('--log-max', '--log_max', dest='log_max', type='int', parser.add_option('--log-max', '--log_max', dest='log_max', type='int',
default=_DEFAULT_LOG_MAX_BYTES, default=_DEFAULT_LOG_MAX_BYTES,
help='Log maximum bytes') help='Log maximum bytes')
@ -613,7 +650,39 @@ def _main():
parser.add_option('-q', '--queue', dest='request_queue_size', type='int', parser.add_option('-q', '--queue', dest='request_queue_size', type='int',
default=_DEFAULT_REQUEST_QUEUE_SIZE, default=_DEFAULT_REQUEST_QUEUE_SIZE,
help='request queue size') help='request queue size')
options = parser.parse_args()[0]
return parser
class ThreadMonitor(threading.Thread):
daemon = True
def __init__(self, interval_in_sec):
threading.Thread.__init__(self, name='ThreadMonitor')
self._logger = util.get_class_logger(self)
self._interval_in_sec = interval_in_sec
def run(self):
while True:
thread_name_list = []
for thread in threading.enumerate():
thread_name_list.append(thread.name)
self._logger.info(
"%d active threads: %s",
threading.active_count(),
', '.join(thread_name_list))
time.sleep(self._interval_in_sec)
def _main(args=None):
parser = _build_option_parser()
options, args = parser.parse_args(args=args)
if args:
logging.critical('Unrecognized positional arguments: %r', args)
sys.exit(1)
os.chdir(options.document_root) os.chdir(options.document_root)
@ -653,14 +722,24 @@ def _main():
options.scan_dir = options.websock_handlers options.scan_dir = options.websock_handlers
try: try:
if options.thread_monitor_interval_in_sec > 0:
# Run a thread monitor to show the status of server threads for
# debugging.
ThreadMonitor(options.thread_monitor_interval_in_sec).start()
# Share a Dispatcher among request handlers to save time for # Share a Dispatcher among request handlers to save time for
# instantiation. Dispatcher can be shared because it is thread-safe. # instantiation. Dispatcher can be shared because it is thread-safe.
options.dispatcher = dispatch.Dispatcher(options.websock_handlers, options.dispatcher = dispatch.Dispatcher(
options.scan_dir) options.websock_handlers,
options.scan_dir,
options.allow_handlers_outside_root_dir)
if options.websock_handlers_map_file: if options.websock_handlers_map_file:
_alias_handlers(options.dispatcher, _alias_handlers(options.dispatcher,
options.websock_handlers_map_file) options.websock_handlers_map_file)
_print_warnings_if_any(options.dispatcher) warnings = options.dispatcher.source_warnings()
if warnings:
for warning in warnings:
logging.warning('mod_pywebsocket: %s' % warning)
server = WebSocketServer(options) server = WebSocketServer(options)
server.serve_forever() server.serve_forever()