asd
This commit is contained in:
@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
warnings.warn( # deprecated in 14.0 - 2024-11-09
|
||||
"websockets.legacy is deprecated; "
|
||||
"see https://websockets.readthedocs.io/en/stable/howto/upgrade.html "
|
||||
"for upgrade instructions",
|
||||
DeprecationWarning,
|
||||
)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
190
venv/lib/python3.12/site-packages/websockets/legacy/auth.py
Normal file
190
venv/lib/python3.12/site-packages/websockets/legacy/auth.py
Normal file
@ -0,0 +1,190 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hmac
|
||||
import http
|
||||
from collections.abc import Awaitable, Iterable
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from ..datastructures import Headers
|
||||
from ..exceptions import InvalidHeader
|
||||
from ..headers import build_www_authenticate_basic, parse_authorization_basic
|
||||
from .server import HTTPResponse, WebSocketServerProtocol
|
||||
|
||||
|
||||
__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"]
|
||||
|
||||
Credentials = tuple[str, str]
|
||||
|
||||
|
||||
def is_credentials(value: Any) -> bool:
|
||||
try:
|
||||
username, password = value
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
else:
|
||||
return isinstance(username, str) and isinstance(password, str)
|
||||
|
||||
|
||||
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
|
||||
"""
|
||||
WebSocket server protocol that enforces HTTP Basic Auth.
|
||||
|
||||
"""
|
||||
|
||||
realm: str = ""
|
||||
"""
|
||||
Scope of protection.
|
||||
|
||||
If provided, it should contain only ASCII characters because the
|
||||
encoding of non-ASCII characters is undefined.
|
||||
"""
|
||||
|
||||
username: str | None = None
|
||||
"""Username of the authenticated user."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
realm: str | None = None,
|
||||
check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if realm is not None:
|
||||
self.realm = realm # shadow class attribute
|
||||
self._check_credentials = check_credentials
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def check_credentials(self, username: str, password: str) -> bool:
|
||||
"""
|
||||
Check whether credentials are authorized.
|
||||
|
||||
This coroutine may be overridden in a subclass, for example to
|
||||
authenticate against a database or an external service.
|
||||
|
||||
Args:
|
||||
username: HTTP Basic Auth username.
|
||||
password: HTTP Basic Auth password.
|
||||
|
||||
Returns:
|
||||
:obj:`True` if the handshake should continue;
|
||||
:obj:`False` if it should fail with an HTTP 401 error.
|
||||
|
||||
"""
|
||||
if self._check_credentials is not None:
|
||||
return await self._check_credentials(username, password)
|
||||
|
||||
return False
|
||||
|
||||
async def process_request(
|
||||
self,
|
||||
path: str,
|
||||
request_headers: Headers,
|
||||
) -> HTTPResponse | None:
|
||||
"""
|
||||
Check HTTP Basic Auth and return an HTTP 401 response if needed.
|
||||
|
||||
"""
|
||||
try:
|
||||
authorization = request_headers["Authorization"]
|
||||
except KeyError:
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Missing credentials\n",
|
||||
)
|
||||
|
||||
try:
|
||||
username, password = parse_authorization_basic(authorization)
|
||||
except InvalidHeader:
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Unsupported credentials\n",
|
||||
)
|
||||
|
||||
if not await self.check_credentials(username, password):
|
||||
return (
|
||||
http.HTTPStatus.UNAUTHORIZED,
|
||||
[("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
|
||||
b"Invalid credentials\n",
|
||||
)
|
||||
|
||||
self.username = username
|
||||
|
||||
return await super().process_request(path, request_headers)
|
||||
|
||||
|
||||
def basic_auth_protocol_factory(
|
||||
realm: str | None = None,
|
||||
credentials: Credentials | Iterable[Credentials] | None = None,
|
||||
check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
|
||||
create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
|
||||
) -> Callable[..., BasicAuthWebSocketServerProtocol]:
|
||||
"""
|
||||
Protocol factory that enforces HTTP Basic Auth.
|
||||
|
||||
:func:`basic_auth_protocol_factory` is designed to integrate with
|
||||
:func:`~websockets.legacy.server.serve` like this::
|
||||
|
||||
serve(
|
||||
...,
|
||||
create_protocol=basic_auth_protocol_factory(
|
||||
realm="my dev server",
|
||||
credentials=("hello", "iloveyou"),
|
||||
)
|
||||
)
|
||||
|
||||
Args:
|
||||
realm: Scope of protection. It should contain only ASCII characters
|
||||
because the encoding of non-ASCII characters is undefined.
|
||||
Refer to section 2.2 of :rfc:`7235` for details.
|
||||
credentials: Hard coded authorized credentials. It can be a
|
||||
``(username, password)`` pair or a list of such pairs.
|
||||
check_credentials: Coroutine that verifies credentials.
|
||||
It receives ``username`` and ``password`` arguments
|
||||
and returns a :class:`bool`. One of ``credentials`` or
|
||||
``check_credentials`` must be provided but not both.
|
||||
create_protocol: Factory that creates the protocol. By default, this
|
||||
is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
|
||||
by a subclass.
|
||||
Raises:
|
||||
TypeError: If the ``credentials`` or ``check_credentials`` argument is
|
||||
wrong.
|
||||
|
||||
"""
|
||||
if (credentials is None) == (check_credentials is None):
|
||||
raise TypeError("provide either credentials or check_credentials")
|
||||
|
||||
if credentials is not None:
|
||||
if is_credentials(credentials):
|
||||
credentials_list = [cast(Credentials, credentials)]
|
||||
elif isinstance(credentials, Iterable):
|
||||
credentials_list = list(cast(Iterable[Credentials], credentials))
|
||||
if not all(is_credentials(item) for item in credentials_list):
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
else:
|
||||
raise TypeError(f"invalid credentials argument: {credentials}")
|
||||
|
||||
credentials_dict = dict(credentials_list)
|
||||
|
||||
async def check_credentials(username: str, password: str) -> bool:
|
||||
try:
|
||||
expected_password = credentials_dict[username]
|
||||
except KeyError:
|
||||
return False
|
||||
return hmac.compare_digest(expected_password, password)
|
||||
|
||||
if create_protocol is None:
|
||||
create_protocol = BasicAuthWebSocketServerProtocol
|
||||
|
||||
# Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
|
||||
# Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc]
|
||||
create_protocol = cast(
|
||||
Callable[..., BasicAuthWebSocketServerProtocol], create_protocol
|
||||
)
|
||||
return functools.partial(
|
||||
create_protocol,
|
||||
realm=realm,
|
||||
check_credentials=check_credentials,
|
||||
)
|
704
venv/lib/python3.12/site-packages/websockets/legacy/client.py
Normal file
704
venv/lib/python3.12/site-packages/websockets/legacy/client.py
Normal file
@ -0,0 +1,704 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
import urllib.parse
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator, Generator, Sequence
|
||||
from types import TracebackType
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from ..asyncio.compatibility import asyncio_timeout
|
||||
from ..datastructures import Headers, HeadersLike
|
||||
from ..exceptions import (
|
||||
InvalidHeader,
|
||||
InvalidHeaderValue,
|
||||
NegotiationError,
|
||||
SecurityError,
|
||||
)
|
||||
from ..extensions import ClientExtensionFactory, Extension
|
||||
from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
||||
from ..headers import (
|
||||
build_authorization_basic,
|
||||
build_extension,
|
||||
build_host,
|
||||
build_subprotocol,
|
||||
parse_extension,
|
||||
parse_subprotocol,
|
||||
validate_subprotocols,
|
||||
)
|
||||
from ..http11 import USER_AGENT
|
||||
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
|
||||
from ..uri import WebSocketURI, parse_uri
|
||||
from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake
|
||||
from .handshake import build_request, check_response
|
||||
from .http import read_response
|
||||
from .protocol import WebSocketCommonProtocol
|
||||
|
||||
|
||||
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]
|
||||
|
||||
|
||||
class WebSocketClientProtocol(WebSocketCommonProtocol):
|
||||
"""
|
||||
WebSocket client connection.
|
||||
|
||||
:class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send`
|
||||
coroutines for receiving and sending messages.
|
||||
|
||||
It supports asynchronous iteration to receive messages::
|
||||
|
||||
async for message in websocket:
|
||||
await process(message)
|
||||
|
||||
The iterator exits normally when the connection is closed with close code
|
||||
1000 (OK) or 1001 (going away) or without a close code. It raises
|
||||
a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
|
||||
is closed with any other code.
|
||||
|
||||
See :func:`connect` for the documentation of ``logger``, ``origin``,
|
||||
``extensions``, ``subprotocols``, ``extra_headers``, and
|
||||
``user_agent_header``.
|
||||
|
||||
See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
|
||||
documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
|
||||
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
|
||||
|
||||
"""
|
||||
|
||||
is_client = True
|
||||
side = "client"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
logger: LoggerLike | None = None,
|
||||
origin: Origin | None = None,
|
||||
extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
extra_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
super().__init__(logger=logger, **kwargs)
|
||||
self.origin = origin
|
||||
self.available_extensions = extensions
|
||||
self.available_subprotocols = subprotocols
|
||||
self.extra_headers = extra_headers
|
||||
self.user_agent_header = user_agent_header
|
||||
|
||||
def write_http_request(self, path: str, headers: Headers) -> None:
|
||||
"""
|
||||
Write request line and headers to the HTTP request.
|
||||
|
||||
"""
|
||||
self.path = path
|
||||
self.request_headers = headers
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("> GET %s HTTP/1.1", path)
|
||||
for key, value in headers.raw_items():
|
||||
self.logger.debug("> %s: %s", key, value)
|
||||
|
||||
# Since the path and headers only contain ASCII characters,
|
||||
# we can keep this simple.
|
||||
request = f"GET {path} HTTP/1.1\r\n"
|
||||
request += str(headers)
|
||||
|
||||
self.transport.write(request.encode())
|
||||
|
||||
async def read_http_response(self) -> tuple[int, Headers]:
|
||||
"""
|
||||
Read status line and headers from the HTTP response.
|
||||
|
||||
If the response contains a body, it may be read from ``self.reader``
|
||||
after this coroutine returns.
|
||||
|
||||
Raises:
|
||||
InvalidMessage: If the HTTP message is malformed or isn't an
|
||||
HTTP/1.1 GET response.
|
||||
|
||||
"""
|
||||
try:
|
||||
status_code, reason, headers = await read_response(self.reader)
|
||||
except Exception as exc:
|
||||
raise InvalidMessage("did not receive a valid HTTP response") from exc
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("< HTTP/1.1 %d %s", status_code, reason)
|
||||
for key, value in headers.raw_items():
|
||||
self.logger.debug("< %s: %s", key, value)
|
||||
|
||||
self.response_headers = headers
|
||||
|
||||
return status_code, self.response_headers
|
||||
|
||||
@staticmethod
|
||||
def process_extensions(
|
||||
headers: Headers,
|
||||
available_extensions: Sequence[ClientExtensionFactory] | None,
|
||||
) -> list[Extension]:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Extensions HTTP response header.
|
||||
|
||||
Check that each extension is supported, as well as its parameters.
|
||||
|
||||
Return the list of accepted extensions.
|
||||
|
||||
Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
|
||||
connection.
|
||||
|
||||
:rfc:`6455` leaves the rules up to the specification of each
|
||||
:extension.
|
||||
|
||||
To provide this level of flexibility, for each extension accepted by
|
||||
the server, we check for a match with each extension available in the
|
||||
client configuration. If no match is found, an exception is raised.
|
||||
|
||||
If several variants of the same extension are accepted by the server,
|
||||
it may be configured several times, which won't make sense in general.
|
||||
Extensions must implement their own requirements. For this purpose,
|
||||
the list of previously accepted extensions is provided.
|
||||
|
||||
Other requirements, for example related to mandatory extensions or the
|
||||
order of extensions, may be implemented by overriding this method.
|
||||
|
||||
"""
|
||||
accepted_extensions: list[Extension] = []
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Extensions")
|
||||
|
||||
if header_values:
|
||||
if available_extensions is None:
|
||||
raise NegotiationError("no extensions supported")
|
||||
|
||||
parsed_header_values: list[ExtensionHeader] = sum(
|
||||
[parse_extension(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
for name, response_params in parsed_header_values:
|
||||
for extension_factory in available_extensions:
|
||||
# Skip non-matching extensions based on their name.
|
||||
if extension_factory.name != name:
|
||||
continue
|
||||
|
||||
# Skip non-matching extensions based on their params.
|
||||
try:
|
||||
extension = extension_factory.process_response_params(
|
||||
response_params, accepted_extensions
|
||||
)
|
||||
except NegotiationError:
|
||||
continue
|
||||
|
||||
# Add matching extension to the final list.
|
||||
accepted_extensions.append(extension)
|
||||
|
||||
# Break out of the loop once we have a match.
|
||||
break
|
||||
|
||||
# If we didn't break from the loop, no extension in our list
|
||||
# matched what the server sent. Fail the connection.
|
||||
else:
|
||||
raise NegotiationError(
|
||||
f"Unsupported extension: "
|
||||
f"name = {name}, params = {response_params}"
|
||||
)
|
||||
|
||||
return accepted_extensions
|
||||
|
||||
@staticmethod
|
||||
def process_subprotocol(
|
||||
headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
|
||||
) -> Subprotocol | None:
|
||||
"""
|
||||
Handle the Sec-WebSocket-Protocol HTTP response header.
|
||||
|
||||
Check that it contains exactly one supported subprotocol.
|
||||
|
||||
Return the selected subprotocol.
|
||||
|
||||
"""
|
||||
subprotocol: Subprotocol | None = None
|
||||
|
||||
header_values = headers.get_all("Sec-WebSocket-Protocol")
|
||||
|
||||
if header_values:
|
||||
if available_subprotocols is None:
|
||||
raise NegotiationError("no subprotocols supported")
|
||||
|
||||
parsed_header_values: Sequence[Subprotocol] = sum(
|
||||
[parse_subprotocol(header_value) for header_value in header_values], []
|
||||
)
|
||||
|
||||
if len(parsed_header_values) > 1:
|
||||
raise InvalidHeaderValue(
|
||||
"Sec-WebSocket-Protocol",
|
||||
f"multiple values: {', '.join(parsed_header_values)}",
|
||||
)
|
||||
|
||||
subprotocol = parsed_header_values[0]
|
||||
|
||||
if subprotocol not in available_subprotocols:
|
||||
raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
|
||||
|
||||
return subprotocol
|
||||
|
||||
async def handshake(
|
||||
self,
|
||||
wsuri: WebSocketURI,
|
||||
origin: Origin | None = None,
|
||||
available_extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
available_subprotocols: Sequence[Subprotocol] | None = None,
|
||||
extra_headers: HeadersLike | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Perform the client side of the opening handshake.
|
||||
|
||||
Args:
|
||||
wsuri: URI of the WebSocket server.
|
||||
origin: Value of the ``Origin`` header.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
extra_headers: Arbitrary HTTP headers to add to the handshake request.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake fails.
|
||||
|
||||
"""
|
||||
request_headers = Headers()
|
||||
|
||||
request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure)
|
||||
|
||||
if wsuri.user_info:
|
||||
request_headers["Authorization"] = build_authorization_basic(
|
||||
*wsuri.user_info
|
||||
)
|
||||
|
||||
if origin is not None:
|
||||
request_headers["Origin"] = origin
|
||||
|
||||
key = build_request(request_headers)
|
||||
|
||||
if available_extensions is not None:
|
||||
extensions_header = build_extension(
|
||||
[
|
||||
(extension_factory.name, extension_factory.get_request_params())
|
||||
for extension_factory in available_extensions
|
||||
]
|
||||
)
|
||||
request_headers["Sec-WebSocket-Extensions"] = extensions_header
|
||||
|
||||
if available_subprotocols is not None:
|
||||
protocol_header = build_subprotocol(available_subprotocols)
|
||||
request_headers["Sec-WebSocket-Protocol"] = protocol_header
|
||||
|
||||
if self.extra_headers is not None:
|
||||
request_headers.update(self.extra_headers)
|
||||
|
||||
if self.user_agent_header:
|
||||
request_headers.setdefault("User-Agent", self.user_agent_header)
|
||||
|
||||
self.write_http_request(wsuri.resource_name, request_headers)
|
||||
|
||||
status_code, response_headers = await self.read_http_response()
|
||||
if status_code in (301, 302, 303, 307, 308):
|
||||
if "Location" not in response_headers:
|
||||
raise InvalidHeader("Location")
|
||||
raise RedirectHandshake(response_headers["Location"])
|
||||
elif status_code != 101:
|
||||
raise InvalidStatusCode(status_code, response_headers)
|
||||
|
||||
check_response(response_headers, key)
|
||||
|
||||
self.extensions = self.process_extensions(
|
||||
response_headers, available_extensions
|
||||
)
|
||||
|
||||
self.subprotocol = self.process_subprotocol(
|
||||
response_headers, available_subprotocols
|
||||
)
|
||||
|
||||
self.connection_open()
|
||||
|
||||
|
||||
class Connect:
|
||||
"""
|
||||
Connect to the WebSocket server at ``uri``.
|
||||
|
||||
Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
|
||||
can then be used to send and receive messages.
|
||||
|
||||
:func:`connect` can be used as a asynchronous context manager::
|
||||
|
||||
async with connect(...) as websocket:
|
||||
...
|
||||
|
||||
The connection is closed automatically when exiting the context.
|
||||
|
||||
:func:`connect` can be used as an infinite asynchronous iterator to
|
||||
reconnect automatically on errors::
|
||||
|
||||
async for websocket in connect(...):
|
||||
try:
|
||||
...
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
continue
|
||||
|
||||
The connection is closed automatically after each iteration of the loop.
|
||||
|
||||
If an error occurs while establishing the connection, :func:`connect`
|
||||
retries with exponential backoff. The backoff delay starts at three
|
||||
seconds and increases up to one minute.
|
||||
|
||||
If an error occurs in the body of the loop, you can handle the exception
|
||||
and :func:`connect` will reconnect with the next iteration; or you can
|
||||
let the exception bubble up and break out of the loop. This lets you
|
||||
decide which errors trigger a reconnection and which errors are fatal.
|
||||
|
||||
Args:
|
||||
uri: URI of the WebSocket server.
|
||||
create_protocol: Factory for the :class:`asyncio.Protocol` managing
|
||||
the connection. It defaults to :class:`WebSocketClientProtocol`.
|
||||
Set it to a wrapper or a subclass to customize connection handling.
|
||||
logger: Logger for this client.
|
||||
It defaults to ``logging.getLogger("websockets.client")``.
|
||||
See the :doc:`logging guide <../../topics/logging>` for details.
|
||||
compression: The "permessage-deflate" extension is enabled by default.
|
||||
Set ``compression`` to :obj:`None` to disable it. See the
|
||||
:doc:`compression guide <../../topics/compression>` for details.
|
||||
origin: Value of the ``Origin`` header, for servers that require it.
|
||||
extensions: List of supported extensions, in order in which they
|
||||
should be negotiated and run.
|
||||
subprotocols: List of supported subprotocols, in order of decreasing
|
||||
preference.
|
||||
extra_headers: Arbitrary HTTP headers to add to the handshake request.
|
||||
user_agent_header: Value of the ``User-Agent`` request header.
|
||||
It defaults to ``"Python/x.y.z websockets/X.Y"``.
|
||||
Setting it to :obj:`None` removes the header.
|
||||
open_timeout: Timeout for opening the connection in seconds.
|
||||
:obj:`None` disables the timeout.
|
||||
|
||||
See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
|
||||
documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
|
||||
``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
|
||||
|
||||
Any other keyword arguments are passed the event loop's
|
||||
:meth:`~asyncio.loop.create_connection` method.
|
||||
|
||||
For example:
|
||||
|
||||
* You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS
|
||||
settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't
|
||||
provided, a TLS context is created
|
||||
with :func:`~ssl.create_default_context`.
|
||||
|
||||
* You can set ``host`` and ``port`` to connect to a different host and
|
||||
port from those found in ``uri``. This only changes the destination of
|
||||
the TCP connection. The host name from ``uri`` is still used in the TLS
|
||||
handshake for secure connections and in the ``Host`` header.
|
||||
|
||||
Raises:
|
||||
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
|
||||
OSError: If the TCP connection fails.
|
||||
InvalidHandshake: If the opening handshake fails.
|
||||
~asyncio.TimeoutError: If the opening handshake times out.
|
||||
|
||||
"""
|
||||
|
||||
MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
create_protocol: Callable[..., WebSocketClientProtocol] | None = None,
|
||||
logger: LoggerLike | None = None,
|
||||
compression: str | None = "deflate",
|
||||
origin: Origin | None = None,
|
||||
extensions: Sequence[ClientExtensionFactory] | None = None,
|
||||
subprotocols: Sequence[Subprotocol] | None = None,
|
||||
extra_headers: HeadersLike | None = None,
|
||||
user_agent_header: str | None = USER_AGENT,
|
||||
open_timeout: float | None = 10,
|
||||
ping_interval: float | None = 20,
|
||||
ping_timeout: float | None = 20,
|
||||
close_timeout: float | None = None,
|
||||
max_size: int | None = 2**20,
|
||||
max_queue: int | None = 2**5,
|
||||
read_limit: int = 2**16,
|
||||
write_limit: int = 2**16,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Backwards compatibility: close_timeout used to be called timeout.
|
||||
timeout: float | None = kwargs.pop("timeout", None)
|
||||
if timeout is None:
|
||||
timeout = 10
|
||||
else:
|
||||
warnings.warn("rename timeout to close_timeout", DeprecationWarning)
|
||||
# If both are specified, timeout is ignored.
|
||||
if close_timeout is None:
|
||||
close_timeout = timeout
|
||||
|
||||
# Backwards compatibility: create_protocol used to be called klass.
|
||||
klass: type[WebSocketClientProtocol] | None = kwargs.pop("klass", None)
|
||||
if klass is None:
|
||||
klass = WebSocketClientProtocol
|
||||
else:
|
||||
warnings.warn("rename klass to create_protocol", DeprecationWarning)
|
||||
# If both are specified, klass is ignored.
|
||||
if create_protocol is None:
|
||||
create_protocol = klass
|
||||
|
||||
# Backwards compatibility: recv() used to return None on closed connections
|
||||
legacy_recv: bool = kwargs.pop("legacy_recv", False)
|
||||
|
||||
# Backwards compatibility: the loop parameter used to be supported.
|
||||
_loop: asyncio.AbstractEventLoop | None = kwargs.pop("loop", None)
|
||||
if _loop is None:
|
||||
loop = asyncio.get_event_loop()
|
||||
else:
|
||||
loop = _loop
|
||||
warnings.warn("remove loop argument", DeprecationWarning)
|
||||
|
||||
wsuri = parse_uri(uri)
|
||||
if wsuri.secure:
|
||||
kwargs.setdefault("ssl", True)
|
||||
elif kwargs.get("ssl") is not None:
|
||||
raise ValueError(
|
||||
"connect() received a ssl argument for a ws:// URI, "
|
||||
"use a wss:// URI to enable TLS"
|
||||
)
|
||||
|
||||
if compression == "deflate":
|
||||
extensions = enable_client_permessage_deflate(extensions)
|
||||
elif compression is not None:
|
||||
raise ValueError(f"unsupported compression: {compression}")
|
||||
|
||||
if subprotocols is not None:
|
||||
validate_subprotocols(subprotocols)
|
||||
|
||||
# Help mypy and avoid this error: "type[WebSocketClientProtocol] |
|
||||
# Callable[..., WebSocketClientProtocol]" not callable [misc]
|
||||
create_protocol = cast(Callable[..., WebSocketClientProtocol], create_protocol)
|
||||
factory = functools.partial(
|
||||
create_protocol,
|
||||
logger=logger,
|
||||
origin=origin,
|
||||
extensions=extensions,
|
||||
subprotocols=subprotocols,
|
||||
extra_headers=extra_headers,
|
||||
user_agent_header=user_agent_header,
|
||||
ping_interval=ping_interval,
|
||||
ping_timeout=ping_timeout,
|
||||
close_timeout=close_timeout,
|
||||
max_size=max_size,
|
||||
max_queue=max_queue,
|
||||
read_limit=read_limit,
|
||||
write_limit=write_limit,
|
||||
host=wsuri.host,
|
||||
port=wsuri.port,
|
||||
secure=wsuri.secure,
|
||||
legacy_recv=legacy_recv,
|
||||
loop=_loop,
|
||||
)
|
||||
|
||||
if kwargs.pop("unix", False):
|
||||
path: str | None = kwargs.pop("path", None)
|
||||
create_connection = functools.partial(
|
||||
loop.create_unix_connection, factory, path, **kwargs
|
||||
)
|
||||
else:
|
||||
host: str | None
|
||||
port: int | None
|
||||
if kwargs.get("sock") is None:
|
||||
host, port = wsuri.host, wsuri.port
|
||||
else:
|
||||
# If sock is given, host and port shouldn't be specified.
|
||||
host, port = None, None
|
||||
if kwargs.get("ssl"):
|
||||
kwargs.setdefault("server_hostname", wsuri.host)
|
||||
# If host and port are given, override values from the URI.
|
||||
host = kwargs.pop("host", host)
|
||||
port = kwargs.pop("port", port)
|
||||
create_connection = functools.partial(
|
||||
loop.create_connection, factory, host, port, **kwargs
|
||||
)
|
||||
|
||||
self.open_timeout = open_timeout
|
||||
if logger is None:
|
||||
logger = logging.getLogger("websockets.client")
|
||||
self.logger = logger
|
||||
|
||||
# This is a coroutine function.
|
||||
self._create_connection = create_connection
|
||||
self._uri = uri
|
||||
self._wsuri = wsuri
|
||||
|
||||
def handle_redirect(self, uri: str) -> None:
|
||||
# Update the state of this instance to connect to a new URI.
|
||||
old_uri = self._uri
|
||||
old_wsuri = self._wsuri
|
||||
new_uri = urllib.parse.urljoin(old_uri, uri)
|
||||
new_wsuri = parse_uri(new_uri)
|
||||
|
||||
# Forbid TLS downgrade.
|
||||
if old_wsuri.secure and not new_wsuri.secure:
|
||||
raise SecurityError("redirect from WSS to WS")
|
||||
|
||||
same_origin = (
|
||||
old_wsuri.secure == new_wsuri.secure
|
||||
and old_wsuri.host == new_wsuri.host
|
||||
and old_wsuri.port == new_wsuri.port
|
||||
)
|
||||
|
||||
# Rewrite secure, host, and port for cross-origin redirects.
|
||||
# This preserves connection overrides with the host and port
|
||||
# arguments if the redirect points to the same host and port.
|
||||
if not same_origin:
|
||||
factory = self._create_connection.args[0]
|
||||
# Support TLS upgrade.
|
||||
if not old_wsuri.secure and new_wsuri.secure:
|
||||
factory.keywords["secure"] = True
|
||||
self._create_connection.keywords.setdefault("ssl", True)
|
||||
# Replace secure, host, and port arguments of the protocol factory.
|
||||
factory = functools.partial(
|
||||
factory.func,
|
||||
*factory.args,
|
||||
**dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
|
||||
)
|
||||
# Replace secure, host, and port arguments of create_connection.
|
||||
self._create_connection = functools.partial(
|
||||
self._create_connection.func,
|
||||
*(factory, new_wsuri.host, new_wsuri.port),
|
||||
**self._create_connection.keywords,
|
||||
)
|
||||
|
||||
# Set the new WebSocket URI. This suffices for same-origin redirects.
|
||||
self._uri = new_uri
|
||||
self._wsuri = new_wsuri
|
||||
|
||||
# async for ... in connect(...):
|
||||
|
||||
BACKOFF_INITIAL = float(os.environ.get("WEBSOCKETS_BACKOFF_INITIAL_DELAY", "5"))
|
||||
BACKOFF_MIN = float(os.environ.get("WEBSOCKETS_BACKOFF_MIN_DELAY", "3.1"))
|
||||
BACKOFF_MAX = float(os.environ.get("WEBSOCKETS_BACKOFF_MAX_DELAY", "90.0"))
|
||||
BACKOFF_FACTOR = float(os.environ.get("WEBSOCKETS_BACKOFF_FACTOR", "1.618"))
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]:
|
||||
backoff_delay = self.BACKOFF_MIN / self.BACKOFF_FACTOR
|
||||
while True:
|
||||
try:
|
||||
async with self as protocol:
|
||||
yield protocol
|
||||
except Exception as exc:
|
||||
# Add a random initial delay between 0 and 5 seconds.
|
||||
# See 7.2.3. Recovering from Abnormal Closure in RFC 6455.
|
||||
if backoff_delay == self.BACKOFF_MIN:
|
||||
initial_delay = random.random() * self.BACKOFF_INITIAL
|
||||
self.logger.info(
|
||||
"connect failed; reconnecting in %.1f seconds: %s",
|
||||
initial_delay,
|
||||
# Remove first argument when dropping Python 3.9.
|
||||
traceback.format_exception_only(type(exc), exc)[0].strip(),
|
||||
)
|
||||
await asyncio.sleep(initial_delay)
|
||||
else:
|
||||
self.logger.info(
|
||||
"connect failed again; retrying in %d seconds: %s",
|
||||
int(backoff_delay),
|
||||
# Remove first argument when dropping Python 3.9.
|
||||
traceback.format_exception_only(type(exc), exc)[0].strip(),
|
||||
)
|
||||
await asyncio.sleep(int(backoff_delay))
|
||||
# Increase delay with truncated exponential backoff.
|
||||
backoff_delay = backoff_delay * self.BACKOFF_FACTOR
|
||||
backoff_delay = min(backoff_delay, self.BACKOFF_MAX)
|
||||
continue
|
||||
else:
|
||||
# Connection succeeded - reset backoff delay
|
||||
backoff_delay = self.BACKOFF_MIN
|
||||
|
||||
# async with connect(...) as ...:
|
||||
|
||||
async def __aenter__(self) -> WebSocketClientProtocol:
|
||||
return await self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
await self.protocol.close()
|
||||
|
||||
# ... = await connect(...)
|
||||
|
||||
def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
|
||||
# Create a suitable iterator by calling __await__ on a coroutine.
|
||||
return self.__await_impl__().__await__()
|
||||
|
||||
async def __await_impl__(self) -> WebSocketClientProtocol:
|
||||
async with asyncio_timeout(self.open_timeout):
|
||||
for _redirects in range(self.MAX_REDIRECTS_ALLOWED):
|
||||
_transport, protocol = await self._create_connection()
|
||||
try:
|
||||
await protocol.handshake(
|
||||
self._wsuri,
|
||||
origin=protocol.origin,
|
||||
available_extensions=protocol.available_extensions,
|
||||
available_subprotocols=protocol.available_subprotocols,
|
||||
extra_headers=protocol.extra_headers,
|
||||
)
|
||||
except RedirectHandshake as exc:
|
||||
protocol.fail_connection()
|
||||
await protocol.wait_closed()
|
||||
self.handle_redirect(exc.uri)
|
||||
# Avoid leaking a connected socket when the handshake fails.
|
||||
except (Exception, asyncio.CancelledError):
|
||||
protocol.fail_connection()
|
||||
await protocol.wait_closed()
|
||||
raise
|
||||
else:
|
||||
self.protocol = protocol
|
||||
return protocol
|
||||
else:
|
||||
raise SecurityError("too many redirects")
|
||||
|
||||
# ... = yield from connect(...) - remove when dropping Python < 3.10
|
||||
|
||||
__iter__ = __await__
|
||||
|
||||
|
||||
connect = Connect
|
||||
|
||||
|
||||
def unix_connect(
|
||||
path: str | None = None,
|
||||
uri: str = "ws://localhost/",
|
||||
**kwargs: Any,
|
||||
) -> Connect:
|
||||
"""
|
||||
Similar to :func:`connect`, but for connecting to a Unix socket.
|
||||
|
||||
This function builds upon the event loop's
|
||||
:meth:`~asyncio.loop.create_unix_connection` method.
|
||||
|
||||
It is only available on Unix.
|
||||
|
||||
It's mainly useful for debugging servers listening on Unix sockets.
|
||||
|
||||
Args:
|
||||
path: File system path to the Unix socket.
|
||||
uri: URI of the WebSocket server; the host is used in the TLS
|
||||
handshake for secure connections and in the ``Host`` header.
|
||||
|
||||
"""
|
||||
return connect(uri=uri, path=path, unix=True, **kwargs)
|
@ -0,0 +1,78 @@
|
||||
import http
|
||||
|
||||
from .. import datastructures
|
||||
from ..exceptions import (
|
||||
InvalidHandshake,
|
||||
ProtocolError as WebSocketProtocolError, # noqa: F401
|
||||
)
|
||||
from ..typing import StatusLike
|
||||
|
||||
|
||||
class InvalidMessage(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake request or response is malformed.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class InvalidStatusCode(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake response status code is invalid.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, headers: datastructures.Headers) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"server rejected WebSocket connection: HTTP {self.status_code}"
|
||||
|
||||
|
||||
class AbortHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised to abort the handshake on purpose and return an HTTP response.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
The public API is
|
||||
:meth:`~websockets.legacy.server.WebSocketServerProtocol.process_request`.
|
||||
|
||||
Attributes:
|
||||
status (~http.HTTPStatus): HTTP status code.
|
||||
headers (Headers): HTTP response headers.
|
||||
body (bytes): HTTP response body.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status: StatusLike,
|
||||
headers: datastructures.HeadersLike,
|
||||
body: bytes = b"",
|
||||
) -> None:
|
||||
# If a user passes an int instead of an HTTPStatus, fix it automatically.
|
||||
self.status = http.HTTPStatus(status)
|
||||
self.headers = datastructures.Headers(headers)
|
||||
self.body = body
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"HTTP {self.status:d}, "
|
||||
f"{len(self.headers)} headers, "
|
||||
f"{len(self.body)} bytes"
|
||||
)
|
||||
|
||||
|
||||
class RedirectHandshake(InvalidHandshake):
|
||||
"""
|
||||
Raised when a handshake gets redirected.
|
||||
|
||||
This exception is an implementation detail.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, uri: str) -> None:
|
||||
self.uri = uri
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"redirect to {self.uri}"
|
225
venv/lib/python3.12/site-packages/websockets/legacy/framing.py
Normal file
225
venv/lib/python3.12/site-packages/websockets/legacy/framing.py
Normal file
@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from typing import Any, Callable, NamedTuple
|
||||
|
||||
from .. import extensions, frames
|
||||
from ..exceptions import PayloadTooBig, ProtocolError
|
||||
from ..frames import BytesLike
|
||||
from ..typing import Data
|
||||
|
||||
|
||||
try:
|
||||
from ..speedups import apply_mask
|
||||
except ImportError:
|
||||
from ..utils import apply_mask
|
||||
|
||||
|
||||
class Frame(NamedTuple):
|
||||
fin: bool
|
||||
opcode: frames.Opcode
|
||||
data: bytes
|
||||
rsv1: bool = False
|
||||
rsv2: bool = False
|
||||
rsv3: bool = False
|
||||
|
||||
@property
|
||||
def new_frame(self) -> frames.Frame:
|
||||
return frames.Frame(
|
||||
self.opcode,
|
||||
self.data,
|
||||
self.fin,
|
||||
self.rsv1,
|
||||
self.rsv2,
|
||||
self.rsv3,
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.new_frame)
|
||||
|
||||
def check(self) -> None:
|
||||
return self.new_frame.check()
|
||||
|
||||
@classmethod
|
||||
async def read(
|
||||
cls,
|
||||
reader: Callable[[int], Awaitable[bytes]],
|
||||
*,
|
||||
mask: bool,
|
||||
max_size: int | None = None,
|
||||
extensions: Sequence[extensions.Extension] | None = None,
|
||||
) -> Frame:
|
||||
"""
|
||||
Read a WebSocket frame.
|
||||
|
||||
Args:
|
||||
reader: Coroutine that reads exactly the requested number of
|
||||
bytes, unless the end of file is reached.
|
||||
mask: Whether the frame should be masked i.e. whether the read
|
||||
happens on the server side.
|
||||
max_size: Maximum payload size in bytes.
|
||||
extensions: List of extensions, applied in reverse order.
|
||||
|
||||
Raises:
|
||||
PayloadTooBig: If the frame exceeds ``max_size``.
|
||||
ProtocolError: If the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
|
||||
# Read the header.
|
||||
data = await reader(2)
|
||||
head1, head2 = struct.unpack("!BB", data)
|
||||
|
||||
# While not Pythonic, this is marginally faster than calling bool().
|
||||
fin = True if head1 & 0b10000000 else False
|
||||
rsv1 = True if head1 & 0b01000000 else False
|
||||
rsv2 = True if head1 & 0b00100000 else False
|
||||
rsv3 = True if head1 & 0b00010000 else False
|
||||
|
||||
try:
|
||||
opcode = frames.Opcode(head1 & 0b00001111)
|
||||
except ValueError as exc:
|
||||
raise ProtocolError("invalid opcode") from exc
|
||||
|
||||
if (True if head2 & 0b10000000 else False) != mask:
|
||||
raise ProtocolError("incorrect masking")
|
||||
|
||||
length = head2 & 0b01111111
|
||||
if length == 126:
|
||||
data = await reader(2)
|
||||
(length,) = struct.unpack("!H", data)
|
||||
elif length == 127:
|
||||
data = await reader(8)
|
||||
(length,) = struct.unpack("!Q", data)
|
||||
if max_size is not None and length > max_size:
|
||||
raise PayloadTooBig(length, max_size)
|
||||
if mask:
|
||||
mask_bits = await reader(4)
|
||||
|
||||
# Read the data.
|
||||
data = await reader(length)
|
||||
if mask:
|
||||
data = apply_mask(data, mask_bits)
|
||||
|
||||
new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3)
|
||||
|
||||
if extensions is None:
|
||||
extensions = []
|
||||
for extension in reversed(extensions):
|
||||
new_frame = extension.decode(new_frame, max_size=max_size)
|
||||
|
||||
new_frame.check()
|
||||
|
||||
return cls(
|
||||
new_frame.fin,
|
||||
new_frame.opcode,
|
||||
new_frame.data,
|
||||
new_frame.rsv1,
|
||||
new_frame.rsv2,
|
||||
new_frame.rsv3,
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
write: Callable[[bytes], Any],
|
||||
*,
|
||||
mask: bool,
|
||||
extensions: Sequence[extensions.Extension] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write a WebSocket frame.
|
||||
|
||||
Args:
|
||||
frame: Frame to write.
|
||||
write: Function that writes bytes.
|
||||
mask: Whether the frame should be masked i.e. whether the write
|
||||
happens on the client side.
|
||||
extensions: List of extensions, applied in order.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If the frame contains incorrect values.
|
||||
|
||||
"""
|
||||
# The frame is written in a single call to write in order to prevent
|
||||
# TCP fragmentation. See #68 for details. This also makes it safe to
|
||||
# send frames concurrently from multiple coroutines.
|
||||
write(self.new_frame.serialize(mask=mask, extensions=extensions))
|
||||
|
||||
|
||||
def prepare_data(data: Data) -> tuple[int, bytes]:
|
||||
"""
|
||||
Convert a string or byte-like object to an opcode and a bytes-like object.
|
||||
|
||||
This function is designed for data frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return ``OP_TEXT`` and a :class:`bytes`
|
||||
object encoding ``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return ``OP_BINARY`` and a bytes-like
|
||||
object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return frames.Opcode.TEXT, data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return frames.Opcode.BINARY, data
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
def prepare_ctrl(data: Data) -> bytes:
|
||||
"""
|
||||
Convert a string or byte-like object to bytes.
|
||||
|
||||
This function is designed for ping and pong frames.
|
||||
|
||||
If ``data`` is a :class:`str`, return a :class:`bytes` object encoding
|
||||
``data`` in UTF-8.
|
||||
|
||||
If ``data`` is a bytes-like object, return a :class:`bytes` object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ``data`` doesn't have a supported type.
|
||||
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
return data.encode()
|
||||
elif isinstance(data, BytesLike):
|
||||
return bytes(data)
|
||||
else:
|
||||
raise TypeError("data must be str or bytes-like")
|
||||
|
||||
|
||||
# Backwards compatibility with previously documented public APIs
|
||||
encode_data = prepare_ctrl
|
||||
|
||||
# Backwards compatibility with previously documented public APIs
|
||||
from ..frames import Close # noqa: E402 F401, I001
|
||||
|
||||
|
||||
def parse_close(data: bytes) -> tuple[int, str]:
|
||||
"""
|
||||
Parse the payload from a close frame.
|
||||
|
||||
Returns:
|
||||
Close code and reason.
|
||||
|
||||
Raises:
|
||||
ProtocolError: If data is ill-formed.
|
||||
UnicodeDecodeError: If the reason isn't valid UTF-8.
|
||||
|
||||
"""
|
||||
close = Close.parse(data)
|
||||
return close.code, close.reason
|
||||
|
||||
|
||||
def serialize_close(code: int, reason: str) -> bytes:
|
||||
"""
|
||||
Serialize the payload for a close frame.
|
||||
|
||||
"""
|
||||
return Close(code, reason).serialize()
|
158
venv/lib/python3.12/site-packages/websockets/legacy/handshake.py
Normal file
158
venv/lib/python3.12/site-packages/websockets/legacy/handshake.py
Normal file
@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
|
||||
from ..datastructures import Headers, MultipleValuesError
|
||||
from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
|
||||
from ..headers import parse_connection, parse_upgrade
|
||||
from ..typing import ConnectionOption, UpgradeProtocol
|
||||
from ..utils import accept_key as accept, generate_key
|
||||
|
||||
|
||||
__all__ = ["build_request", "check_request", "build_response", "check_response"]
|
||||
|
||||
|
||||
def build_request(headers: Headers) -> str:
|
||||
"""
|
||||
Build a handshake request to send to the server.
|
||||
|
||||
Update request headers passed in argument.
|
||||
|
||||
Args:
|
||||
headers: Handshake request headers.
|
||||
|
||||
Returns:
|
||||
``key`` that must be passed to :func:`check_response`.
|
||||
|
||||
"""
|
||||
key = generate_key()
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Key"] = key
|
||||
headers["Sec-WebSocket-Version"] = "13"
|
||||
return key
|
||||
|
||||
|
||||
def check_request(headers: Headers) -> str:
|
||||
"""
|
||||
Check a handshake request received from the client.
|
||||
|
||||
This function doesn't verify that the request is an HTTP/1.1 or higher GET
|
||||
request and doesn't perform ``Host`` and ``Origin`` checks. These controls
|
||||
are usually performed earlier in the HTTP request handling code. They're
|
||||
the responsibility of the caller.
|
||||
|
||||
Args:
|
||||
headers: Handshake request headers.
|
||||
|
||||
Returns:
|
||||
``key`` that must be passed to :func:`build_response`.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake request is invalid.
|
||||
Then, the server must return a 400 Bad Request error.
|
||||
|
||||
"""
|
||||
connection: list[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade("Connection", ", ".join(connection))
|
||||
|
||||
upgrade: list[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
|
||||
|
||||
try:
|
||||
s_w_key = headers["Sec-WebSocket-Key"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Key", "multiple values") from exc
|
||||
|
||||
try:
|
||||
raw_key = base64.b64decode(s_w_key.encode(), validate=True)
|
||||
except binascii.Error as exc:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc
|
||||
if len(raw_key) != 16:
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key)
|
||||
|
||||
try:
|
||||
s_w_version = headers["Sec-WebSocket-Version"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Version", "multiple values") from exc
|
||||
|
||||
if s_w_version != "13":
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version)
|
||||
|
||||
return s_w_key
|
||||
|
||||
|
||||
def build_response(headers: Headers, key: str) -> None:
|
||||
"""
|
||||
Build a handshake response to send to the client.
|
||||
|
||||
Update response headers passed in argument.
|
||||
|
||||
Args:
|
||||
headers: Handshake response headers.
|
||||
key: Returned by :func:`check_request`.
|
||||
|
||||
"""
|
||||
headers["Upgrade"] = "websocket"
|
||||
headers["Connection"] = "Upgrade"
|
||||
headers["Sec-WebSocket-Accept"] = accept(key)
|
||||
|
||||
|
||||
def check_response(headers: Headers, key: str) -> None:
|
||||
"""
|
||||
Check a handshake response received from the server.
|
||||
|
||||
This function doesn't verify that the response is an HTTP/1.1 or higher
|
||||
response with a 101 status code. These controls are the responsibility of
|
||||
the caller.
|
||||
|
||||
Args:
|
||||
headers: Handshake response headers.
|
||||
key: Returned by :func:`build_request`.
|
||||
|
||||
Raises:
|
||||
InvalidHandshake: If the handshake response is invalid.
|
||||
|
||||
"""
|
||||
connection: list[ConnectionOption] = sum(
|
||||
[parse_connection(value) for value in headers.get_all("Connection")], []
|
||||
)
|
||||
|
||||
if not any(value.lower() == "upgrade" for value in connection):
|
||||
raise InvalidUpgrade("Connection", " ".join(connection))
|
||||
|
||||
upgrade: list[UpgradeProtocol] = sum(
|
||||
[parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
||||
)
|
||||
|
||||
# For compatibility with non-strict implementations, ignore case when
|
||||
# checking the Upgrade header. The RFC always uses "websocket", except
|
||||
# in section 11.2. (IANA registration) where it uses "WebSocket".
|
||||
if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
|
||||
raise InvalidUpgrade("Upgrade", ", ".join(upgrade))
|
||||
|
||||
try:
|
||||
s_w_accept = headers["Sec-WebSocket-Accept"]
|
||||
except KeyError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept") from exc
|
||||
except MultipleValuesError as exc:
|
||||
raise InvalidHeader("Sec-WebSocket-Accept", "multiple values") from exc
|
||||
|
||||
if s_w_accept != accept(key):
|
||||
raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
|
201
venv/lib/python3.12/site-packages/websockets/legacy/http.py
Normal file
201
venv/lib/python3.12/site-packages/websockets/legacy/http.py
Normal file
@ -0,0 +1,201 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
|
||||
from ..datastructures import Headers
|
||||
from ..exceptions import SecurityError
|
||||
|
||||
|
||||
__all__ = ["read_request", "read_response"]
|
||||
|
||||
MAX_NUM_HEADERS = int(os.environ.get("WEBSOCKETS_MAX_NUM_HEADERS", "128"))
|
||||
MAX_LINE_LENGTH = int(os.environ.get("WEBSOCKETS_MAX_LINE_LENGTH", "8192"))
|
||||
|
||||
|
||||
def d(value: bytes) -> str:
|
||||
"""
|
||||
Decode a bytestring for interpolating into an error message.
|
||||
|
||||
"""
|
||||
return value.decode(errors="backslashreplace")
|
||||
|
||||
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7230#appendix-B.
|
||||
|
||||
# Regex for validating header names.
|
||||
|
||||
_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+")
|
||||
|
||||
# Regex for validating header values.
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff).
|
||||
|
||||
# The ABNF is complicated because it attempts to express that optional
|
||||
# whitespace is ignored. We strip whitespace and don't revalidate that.
|
||||
|
||||
# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
|
||||
|
||||
_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*")
|
||||
|
||||
|
||||
async def read_request(stream: asyncio.StreamReader) -> tuple[str, Headers]:
|
||||
"""
|
||||
Read an HTTP/1.1 GET request and return ``(path, headers)``.
|
||||
|
||||
``path`` isn't URL-decoded or validated in any way.
|
||||
|
||||
``path`` and ``headers`` are expected to contain only ASCII characters.
|
||||
Other characters are represented with surrogate escapes.
|
||||
|
||||
:func:`read_request` doesn't attempt to read the request body because
|
||||
WebSocket handshake requests don't have one. If the request contains a
|
||||
body, it may be read from ``stream`` after this coroutine returns.
|
||||
|
||||
Args:
|
||||
stream: Input to read the request from.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a full HTTP request.
|
||||
SecurityError: If the request exceeds a security limit.
|
||||
ValueError: If the request isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.1
|
||||
|
||||
# Parsing is simple because fixed values are expected for method and
|
||||
# version and because path isn't checked. Since WebSocket software tends
|
||||
# to implement HTTP/1.1 strictly, there's little need for lenient parsing.
|
||||
|
||||
try:
|
||||
request_line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP request line") from exc
|
||||
|
||||
try:
|
||||
method, raw_path, version = request_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None
|
||||
|
||||
if method != b"GET":
|
||||
raise ValueError(f"unsupported HTTP method: {d(method)}")
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
path = raw_path.decode("ascii", "surrogateescape")
|
||||
|
||||
headers = await read_headers(stream)
|
||||
|
||||
return path, headers
|
||||
|
||||
|
||||
async def read_response(stream: asyncio.StreamReader) -> tuple[int, str, Headers]:
|
||||
"""
|
||||
Read an HTTP/1.1 response and return ``(status_code, reason, headers)``.
|
||||
|
||||
``reason`` and ``headers`` are expected to contain only ASCII characters.
|
||||
Other characters are represented with surrogate escapes.
|
||||
|
||||
:func:`read_request` doesn't attempt to read the response body because
|
||||
WebSocket handshake responses don't have one. If the response contains a
|
||||
body, it may be read from ``stream`` after this coroutine returns.
|
||||
|
||||
Args:
|
||||
stream: Input to read the response from.
|
||||
|
||||
Raises:
|
||||
EOFError: If the connection is closed without a full HTTP response.
|
||||
SecurityError: If the response exceeds a security limit.
|
||||
ValueError: If the response isn't well formatted.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.1.2
|
||||
|
||||
# As in read_request, parsing is simple because a fixed value is expected
|
||||
# for version, status_code is a 3-digit number, and reason can be ignored.
|
||||
|
||||
try:
|
||||
status_line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP status line") from exc
|
||||
|
||||
try:
|
||||
version, raw_status_code, raw_reason = status_line.split(b" ", 2)
|
||||
except ValueError: # not enough values to unpack (expected 3, got 1-2)
|
||||
raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None
|
||||
|
||||
if version != b"HTTP/1.1":
|
||||
raise ValueError(f"unsupported HTTP version: {d(version)}")
|
||||
try:
|
||||
status_code = int(raw_status_code)
|
||||
except ValueError: # invalid literal for int() with base 10
|
||||
raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None
|
||||
if not 100 <= status_code < 1000:
|
||||
raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}")
|
||||
if not _value_re.fullmatch(raw_reason):
|
||||
raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}")
|
||||
reason = raw_reason.decode()
|
||||
|
||||
headers = await read_headers(stream)
|
||||
|
||||
return status_code, reason, headers
|
||||
|
||||
|
||||
async def read_headers(stream: asyncio.StreamReader) -> Headers:
|
||||
"""
|
||||
Read HTTP headers from ``stream``.
|
||||
|
||||
Non-ASCII characters are represented with surrogate escapes.
|
||||
|
||||
"""
|
||||
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.2
|
||||
|
||||
# We don't attempt to support obsolete line folding.
|
||||
|
||||
headers = Headers()
|
||||
for _ in range(MAX_NUM_HEADERS + 1):
|
||||
try:
|
||||
line = await read_line(stream)
|
||||
except EOFError as exc:
|
||||
raise EOFError("connection closed while reading HTTP headers") from exc
|
||||
if line == b"":
|
||||
break
|
||||
|
||||
try:
|
||||
raw_name, raw_value = line.split(b":", 1)
|
||||
except ValueError: # not enough values to unpack (expected 2, got 1)
|
||||
raise ValueError(f"invalid HTTP header line: {d(line)}") from None
|
||||
if not _token_re.fullmatch(raw_name):
|
||||
raise ValueError(f"invalid HTTP header name: {d(raw_name)}")
|
||||
raw_value = raw_value.strip(b" \t")
|
||||
if not _value_re.fullmatch(raw_value):
|
||||
raise ValueError(f"invalid HTTP header value: {d(raw_value)}")
|
||||
|
||||
name = raw_name.decode("ascii") # guaranteed to be ASCII at this point
|
||||
value = raw_value.decode("ascii", "surrogateescape")
|
||||
headers[name] = value
|
||||
|
||||
else:
|
||||
raise SecurityError("too many HTTP headers")
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
async def read_line(stream: asyncio.StreamReader) -> bytes:
|
||||
"""
|
||||
Read a single line from ``stream``.
|
||||
|
||||
CRLF is stripped from the return value.
|
||||
|
||||
"""
|
||||
# Security: this is bounded by the StreamReader's limit (default = 32 KiB).
|
||||
line = await stream.readline()
|
||||
# Security: this guarantees header values are small (hard-coded = 8 KiB)
|
||||
if len(line) > MAX_LINE_LENGTH:
|
||||
raise SecurityError("line too long")
|
||||
# Not mandatory but safe - https://datatracker.ietf.org/doc/html/rfc7230#section-3.5
|
||||
if not line.endswith(b"\r\n"):
|
||||
raise EOFError("line without CRLF")
|
||||
return line[:-2]
|
1641
venv/lib/python3.12/site-packages/websockets/legacy/protocol.py
Normal file
1641
venv/lib/python3.12/site-packages/websockets/legacy/protocol.py
Normal file
File diff suppressed because it is too large
Load Diff
1190
venv/lib/python3.12/site-packages/websockets/legacy/server.py
Normal file
1190
venv/lib/python3.12/site-packages/websockets/legacy/server.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user