import asyncio
import logging
import ssl
import urllib.parse as urlparse
from typing import Dict, Optional, Type, Union, overload
import aiowamp
from aiowamp import CommonTransportConfig, JSONSerializer, MessagePackSerializer, TransportABC, TransportError, \
register_transport_factory
# Public API of this module.
__all__ = [
    "RawSocketTransport",
    "connect_raw_socket",
    "get_serializer_protocol",
]

log = logging.getLogger(__name__)

MAGIC_OCTET = b"\x7F"
"""Magic octet that opens the raw socket handshake (sent by the client and expected back from the router)."""
class RawSocketTransport(TransportABC):
    """WAMP transport over raw sockets.

    Notes:
        The `start` method needs to be called before `recv` can read any messages.
        This is done as part of the `perform_client_handshake` procedure.
    """

    __slots__ = ("reader", "writer", "serializer",
                 "_msg_queue",
                 "__recv_limit", "__send_limit",
                 "__read_task")

    reader: asyncio.StreamReader
    """Reader for the underlying stream."""
    writer: asyncio.StreamWriter
    """Writer for the underlying transport."""
    serializer: aiowamp.SerializerABC
    """Serializer used to serialise messages."""

    # Created by `start`; the read loop puts deserialized messages into it.
    _msg_queue: Optional[asyncio.Queue]
    # Max payload size (bytes) we accept / the remote is willing to accept.
    __recv_limit: int
    __send_limit: int
    # Background task running `__read_loop`, created by `start`.
    __read_task: Optional[asyncio.Task]

    def __init__(self, reader: asyncio.StreamReader,
                 writer: asyncio.StreamWriter,
                 serializer: aiowamp.SerializerABC, *,
                 recv_limit: int,
                 send_limit: int) -> None:
        """
        Args:
            reader: Reader to read from.
            writer: Writer to write to.
            serializer: Serializer to use.
            recv_limit: Max amount of bytes willing to receive.
            send_limit: Max amount of bytes remote is willing to receive.

        See Also:
            The `connect_raw_socket` method for opening a connection.
        """
        self.reader = reader
        self.writer = writer
        self.serializer = serializer
        self._msg_queue = None
        self.__recv_limit = recv_limit
        self.__send_limit = send_limit
        self.__read_task = None

    def __repr__(self) -> str:
        return f"RawSocketTransport(reader={self.reader!r},writer={self.writer!r}," \
               f"serializer={self.serializer!r},recv_limit={self.__recv_limit!r}," \
               f"send_limit={self.__send_limit!r})"

    @property
    def open(self) -> bool:
        """Whether the stream is still usable.

        False once the reader reached EOF or holds an exception, or the writer
        is closing.
        """
        return not (self.reader.at_eof() or
                    self.reader.exception() or
                    self.writer.is_closing())

    def start(self) -> None:
        """Create the message queue and spawn the background read loop.

        Raises:
            RuntimeError: If the read loop is already running.
        """
        if self.__read_task and not self.__read_task.done():
            raise RuntimeError("read loop already running!")

        self._msg_queue = asyncio.Queue()
        self.__read_task = asyncio.create_task(self.__read_loop())

    async def close(self) -> None:
        """Close the writer and wait until the underlying stream is closed."""
        log.debug("%s: closing", self)
        self.writer.close()
        await self.writer.wait_closed()

    async def send(self, msg: aiowamp.MessageABC) -> None:
        """Serialize ``msg`` and send it as a regular WAMP frame.

        Args:
            msg: Message to send.

        Raises:
            aiowamp.TransportError: If the serialized message is larger than
                the remote is willing to receive or than the 3-octet length
                field can express.
        """
        log.debug("%s: sending: %r", self, msg)
        data = self.serializer.serialize(msg)

        # The frame length field is only 3 octets wide: even if the remote
        # advertised 2 ** 24 bytes, the largest encodable payload is
        # 2 ** 24 - 1 (int_to_bytes would raise OverflowError otherwise).
        if len(data) > min(self.__send_limit, 0xFFFFFF):
            raise TransportError("message longer than remote is willing to receive")

        # frame type 0 = regular WAMP message
        header = b"\x00" + int_to_bytes(len(data))

        if log.isEnabledFor(logging.DEBUG):
            log.debug("%s: writing header: %s", self, header.hex())
            import binascii
            log.debug("%s: writing data: %s", self, binascii.b2a_base64(data))

        # No await between the two writes, so frames from concurrent senders
        # cannot interleave.
        self.writer.write(header)
        self.writer.write(data)
        await self.writer.drain()

    async def __read_once(self) -> None:
        """Read a single frame and dispatch it by its frame type.

        Regular messages are deserialized and queued; PINGs are answered with
        a PONG echoing the payload; PONGs and unknown frame types are read and
        discarded.

        Raises:
            aiowamp.TransportError: If the announced length exceeds the receive
                limit. The transport is closed before raising.
        """
        assert self._msg_queue is not None

        try:
            header = await self.reader.readexactly(4)
        except asyncio.IncompleteReadError:
            # stream ended between frames; the read loop re-checks `open`
            return

        length = bytes_to_int(header[1:])
        if length > self.__recv_limit:
            await self.close()
            raise TransportError("received message bigger than receive limit")

        t_type = header[0]

        if t_type == 0:
            # regular WAMP message: read it and add it to the queue
            data = await self.reader.readexactly(length)
            msg = self.serializer.deserialize(data)
            log.debug("%s: received message: %r", self, msg)
            await self._msg_queue.put(msg)
        elif t_type == 1:
            # PING
            log.debug("%s: received PING, sending PONG", self)
            # Read the whole payload before writing anything. Awaiting between
            # the PONG header and body would let a concurrent `send` slip its
            # frame in between, and a truncated PING would leave a dangling
            # header on the wire.
            payload = await self.reader.readexactly(length)
            # header with t_type = PONG, echoing the PING body
            self.writer.write(b"\x02" + header[1:] + payload)
            await self.writer.drain()
        elif t_type == 2:
            # PONG
            log.debug("%s: received PONG", self)
            # discard body
            await self.reader.readexactly(length)
        else:
            log.warning("%s: received header with unknown op code: %s", self, t_type)
            await self.reader.readexactly(length)

    async def __read_loop(self) -> None:
        """Read frames until the stream is no longer open or the task is cancelled.

        Errors from individual reads are logged and the loop keeps going as
        long as the stream is open.
        """
        log.debug("%s: starting read loop", self)
        while self.open:
            try:
                await self.__read_once()
            except asyncio.CancelledError:
                break
            except Exception:
                log.exception("%s: error while reading once", self)

        log.debug("%s: exiting read loop", self)

    async def recv(self) -> aiowamp.MessageABC:
        """Wait for the next message received by the read loop.

        Raises:
            RuntimeError: If `start` was never called.
        """
        queue = self._msg_queue
        if queue is None:
            raise RuntimeError("cannot receive message before message loop is started.")

        # NOTE(review): the read loop does not enqueue anything when it exits,
        # so this waits indefinitely once the connection is gone.
        return await queue.get()
def int_to_bytes(i: int) -> bytes:
    """Encode an integer as the 3-octet big-endian length used by raw socket frames.

    Args:
        i: Non-negative integer smaller than ``2 ** 24``.

    Returns:
        Exactly three bytes, most significant byte first.

    Raises:
        OverflowError: If ``i`` is negative or doesn't fit into three bytes.
    """
    width = 3
    return int.to_bytes(i, width, byteorder="big", signed=False)
def bytes_to_int(d: bytes) -> int:
    """Decode a big-endian, unsigned byte string (e.g. a frame length) into an int.

    Args:
        d: Bytes to decode, most significant byte first.

    Returns:
        The unsigned integer value; ``0`` for empty input.
    """
    return int.from_bytes(d, byteorder="big", signed=False)
def byte_limit_to_size(limit: int) -> int:
    """Translate a handshake length exponent into a size in bytes.

    The size is ``2 ** (9 + limit)``, so ``0`` yields 512 bytes and ``0xf``
    yields ``2 ** 24`` bytes.
    """
    exponent = 9 + limit
    return 1 << exponent
def size_to_byte_limit(recv_limit: int) -> int:
    """Find the smallest handshake length exponent whose size covers ``recv_limit``.

    Non-positive limits, and limits larger than the biggest representable
    size, yield the maximum exponent ``0xf``.
    """
    if not recv_limit > 0:
        return 0xf

    return next(
        (exponent for exponent in range(0x10) if byte_limit_to_size(exponent) >= recv_limit),
        0xf,
    )
# Error codes a router can send back in the upper nibble of the second
# handshake octet (when the protocol nibble is 0), mapped to the exception
# reported for them. The code is the position in the tuple below.
HANDSHAKE_ERRCODE_EXCEPTIONS = {
    code: TransportError(reason)
    for code, reason in enumerate((
        "illegal error code",
        "serializer unsupported",
        "maximum message length unacceptable",
        "use of reserved bits",
        "maximum connection count reached",
    ))
}
async def perform_client_handshake(reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
                                   recv_limit: int, protocol: int, *,
                                   serializer: aiowamp.SerializerABC,
                                   ) -> RawSocketTransport:
    """Perform the raw socket client handshake and establish the transport.

    The client sends four octets: the magic octet, one octet holding its
    receive-length exponent (upper nibble) and the serializer protocol (lower
    nibble), and two reserved zero octets. The router answers in the same
    layout. It either echoes the protocol together with its own length exponent,
    or sends a zero protocol nibble with an error code in the upper nibble.

    Args:
        reader: Reader to read from.
        writer: Writer to write to.
        recv_limit: Receive limit in bytes.
        protocol: Serialization protocol to use (must fit into 4 bits, 1-15).
        serializer: Serializer to use for the transport.

    Returns:
        Established raw socket transport with its read loop already started.

    Raises:
        ValueError: If ``protocol`` doesn't fit into the 4-bit protocol field.
        aiowamp.TransportError: When the handshake fails.
    """
    # The protocol shares its octet with the length exponent, and a protocol
    # nibble of 0 is how the router signals an error.
    if not 0 < protocol <= 0xf:
        raise ValueError(f"protocol must be in range 1-15, got {protocol}")

    recv_byte_limit = size_to_byte_limit(recv_limit)

    handshake_data = bytearray(MAGIC_OCTET)
    handshake_data.append((recv_byte_limit & 0xf) << 4 | protocol)
    handshake_data.extend((0, 0))

    if log.isEnabledFor(logging.DEBUG):
        log.debug("sending handshake: %s", handshake_data.hex())

    writer.write(handshake_data)
    await writer.drain()

    try:
        resp = await reader.readexactly(4)
    except asyncio.IncompleteReadError as e:
        raise TransportError("remote closed connection during handshake") from e

    if log.isEnabledFor(logging.DEBUG):
        log.debug("received handshake response: %s", resp.hex())

    # use 1-slice to get bytes instead of int
    if resp[0:1] != MAGIC_OCTET:
        raise TransportError("received invalid magic octet while performing handshake. "
                             f"Expected {MAGIC_OCTET!r}, got {resp[0:1]!r}")

    if resp[2:] != b"\x00\x00":
        raise TransportError("expected 3rd and 4th octet to be all zeroes (reserved). "
                             f"Saw {resp[2:].hex()}")

    proto_echo = resp[1] & 0xf

    # if the lower 4 bits are 0 it's an error response
    if proto_echo == 0:
        error_code = resp[1] >> 4
        try:
            template = HANDSHAKE_ERRCODE_EXCEPTIONS[error_code]
        except KeyError:
            raise TransportError(f"unknown error code: {error_code}") from None

        # Raise a fresh copy: re-raising the shared module-level instance would
        # keep extending its __traceback__ (and the frames it references) on
        # every failed handshake.
        raise type(template)(*template.args)

    # router must echo the protocol
    if proto_echo != protocol:
        raise TransportError("router didn't echo protocol. "
                             f"Expected {protocol}, got {proto_echo}")

    recv_limit = byte_limit_to_size(recv_byte_limit)
    send_limit = byte_limit_to_size(resp[1] >> 4)

    transport = RawSocketTransport(reader, writer, serializer,
                                   recv_limit=recv_limit, send_limit=send_limit)
    transport.start()
    return transport
def is_secure_scheme(scheme: str) -> bool:
    """Tell whether a raw socket URL scheme requires TLS.

    Args:
        scheme: URL scheme to test (compared case-sensitively).

    Returns:
        True for the TLS variants ``tcps``, ``tcp4s``, ``tcp6s`` and ``rss``.
    """
    secure_schemes = frozenset(("tcps", "tcp4s", "tcp6s", "rss"))
    return scheme in secure_schemes
async def _connect(url: Union[str, urlparse.ParseResult], serializer: aiowamp.SerializerABC, *,
                   ssl_context: Optional[Union[ssl.SSLContext, bool]] = None,
                   recv_limit: int = 0) -> RawSocketTransport:
    """Connect to a WAMP router over raw socket.

    Args:
        url: URL of the router.
        serializer: Serializer to use.
        ssl_context: Set custom SSL context options. Secure schemes fall back
            to the default context (``True``) when none is given; non-secure
            schemes must not set one.
        recv_limit: Receive limit in bytes.
            Defaults to the max size of 16mb.

    Returns:
        The connected raw socket transport.

    Raises:
        ValueError: If an SSL context is given for a scheme without TLS.
        aiowamp.TransportError: If the handshake fails.
    """
    if not isinstance(url, urlparse.ParseResult):
        url = urlparse.urlparse(url)

    if is_secure_scheme(url.scheme):
        if not ssl_context:
            # create default ssl context
            ssl_context = True
    elif ssl_context:
        raise ValueError(f"SSL context specified for a uri which doesn't use TLS: {url.scheme!r}.")

    reader, writer = await asyncio.open_connection(url.hostname, url.port, ssl=ssl_context)

    log.debug("performing handshake")
    try:
        return await perform_client_handshake(
            reader, writer, recv_limit, get_serializer_protocol(serializer),
            serializer=serializer,
        )
    except BaseException:
        # BaseException so the socket is also released when the handshake is
        # cancelled (asyncio.CancelledError is not an Exception subclass on
        # Python 3.8+). Don't wait for the connection to close.
        writer.close()
        raise
@register_transport_factory("tcp", "tcps",
                            "tcp4", "tcp4s",
                            "tcp6", "tcp6s",
                            "rs", "rss")
async def _connect_config(config: aiowamp.CommonTransportConfig) -> RawSocketTransport:
    """Transport factory for raw socket URLs, driven by a common transport config.

    A `JSONSerializer` is used when the config doesn't provide a serializer.
    """
    url = config.url
    serializer = config.serializer or JSONSerializer()
    return await _connect(url, serializer, ssl_context=config.ssl_context)
@overload
async def connect_raw_socket(url: Union[str, urlparse.ParseResult], serializer: aiowamp.SerializerABC, *,
                             ssl_context: Optional[Union[ssl.SSLContext, bool]] = None,
                             recv_limit: int = 0) -> RawSocketTransport: ...


@overload
async def connect_raw_socket(config: aiowamp.CommonTransportConfig) -> RawSocketTransport: ...


async def connect_raw_socket(*args, **kwargs) -> RawSocketTransport:
    """Connect to a WAMP router over raw socket.

    Accepts either a URL plus serializer (with optional ``ssl_context`` and
    ``recv_limit``), or a `CommonTransportConfig`, passed positionally or as
    the ``config`` keyword.

    Returns:
        The connected raw socket transport.
    """
    # The config overload may also be called as `connect_raw_socket(config=...)`.
    first = args[0] if args else kwargs.get("config")
    if isinstance(first, CommonTransportConfig):
        return await _connect_config(*args, **kwargs)

    return await _connect(*args, **kwargs)
# Raw socket protocol numbers (sent in the lower nibble of the second
# handshake octet) for the serializers shipped with aiowamp.
_BUILTIN_PROTOCOLS: Dict[Type[aiowamp.SerializerABC], int] = {
    JSONSerializer: 1,
    MessagePackSerializer: 2,
}


def get_serializer_protocol(serializer: aiowamp.SerializerABC) -> int:
    """Get the raw socket protocol number for a serializer.

    Subclasses of a built-in serializer map to their base class's protocol.

    Args:
        serializer: Serializer to look up.

    Returns:
        Protocol number used during the raw socket handshake.

    Raises:
        KeyError: If the serializer isn't (a subclass of) a built-in serializer.
    """
    serializer_type = type(serializer)
    for cls in serializer_type.__mro__:
        if cls in _BUILTIN_PROTOCOLS:
            return _BUILTIN_PROTOCOLS[cls]

    raise KeyError(f"no raw socket protocol known for serializer {serializer_type.__qualname__}")