blob: 3f15654a1af4678a829aa0a1311d9c198282cd86 [file] [log] [blame]
import asyncio
import contextvars
import unittest
import sys
from unittest import TestCase
try:
import ssl
except ImportError:
ssl = None
from test.test_asyncio import utils as test_utils
def tearDownModule():
asyncio.events._set_event_loop_policy(None)
class ServerContextvarsTestCase:
loop_factory = None # To be defined in subclasses
server_ssl_context = None # To be defined in subclasses for SSL tests
client_ssl_context = None # To be defined in subclasses for SSL tests
def run_coro(self, coro):
return asyncio.run(coro, loop_factory=self.loop_factory)
def test_start_server1(self):
# Test that asyncio.start_server captures the context at the time of server creation
async def test():
var = contextvars.ContextVar("var", default="default")
async def handle_client(reader, writer):
value = var.get()
writer.write(value.encode())
await writer.drain()
writer.close()
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
ssl=self.server_ssl_context)
# change the value
var.set("after_server")
async def client(addr):
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
writer.close()
await writer.wait_closed()
return data.decode()
async with server:
addr = server.sockets[0].getsockname()
self.assertEqual(await client(addr), "default")
self.assertEqual(var.get(), "after_server")
self.run_coro(test())
def test_start_server2(self):
# Test that mutations to the context in one handler don't affect other handlers or the server's context
async def test():
var = contextvars.ContextVar("var", default="default")
async def handle_client(reader, writer):
value = var.get()
writer.write(value.encode())
var.set("in_handler")
await writer.drain()
writer.close()
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
ssl=self.server_ssl_context)
var.set("after_server")
async def client(addr):
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
writer.close()
await writer.wait_closed()
return data.decode()
async with server:
addr = server.sockets[0].getsockname()
self.assertEqual(await client(addr), "default")
self.assertEqual(await client(addr), "default")
self.assertEqual(await client(addr), "default")
self.assertEqual(var.get(), "after_server")
self.run_coro(test())
def test_start_server3(self):
# Test that mutations to context in concurrent handlers don't affect each other or the server's context
async def test():
var = contextvars.ContextVar("var", default="default")
var.set("before_server")
async def handle_client(reader, writer):
writer.write(var.get().encode())
await writer.drain()
writer.close()
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
ssl=self.server_ssl_context)
var.set("after_server")
async def client(addr):
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
self.assertEqual(data.decode(), "before_server")
writer.close()
await writer.wait_closed()
async with server:
addr = server.sockets[0].getsockname()
async with asyncio.TaskGroup() as tg:
for _ in range(100):
tg.create_task(client(addr))
self.assertEqual(var.get(), "after_server")
self.run_coro(test())
def test_create_server1(self):
# Test that loop.create_server captures the context at the time of server creation
# and that mutations to the context in protocol callbacks don't affect the server's context
async def test():
var = contextvars.ContextVar("var", default="default")
class EchoProtocol(asyncio.Protocol):
def connection_made(self, transport):
self.transport = transport
value = var.get()
var.set("in_handler")
self.transport.write(value.encode())
self.transport.close()
server = await asyncio.get_running_loop().create_server(
lambda: EchoProtocol(), '127.0.0.1', 0,
ssl=self.server_ssl_context)
var.set("after_server")
async def client(addr):
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
self.assertEqual(data.decode(), "default")
writer.close()
await writer.wait_closed()
async with server:
addr = server.sockets[0].getsockname()
await client(addr)
self.assertEqual(var.get(), "after_server")
self.run_coro(test())
def test_create_server2(self):
# Test that mutations to context in one protocol instance don't affect other instances or the server's context
async def test():
var = contextvars.ContextVar("var", default="default")
class EchoProtocol(asyncio.Protocol):
def __init__(self):
super().__init__()
assert var.get() == "default", var.get()
def connection_made(self, transport):
self.transport = transport
value = var.get()
var.set("in_handler")
self.transport.write(value.encode())
self.transport.close()
server = await asyncio.get_running_loop().create_server(
lambda: EchoProtocol(), '127.0.0.1', 0,
ssl=self.server_ssl_context)
var.set("after_server")
async def client(addr, expected):
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
self.assertEqual(data.decode(), expected)
writer.close()
await writer.wait_closed()
async with server:
addr = server.sockets[0].getsockname()
await client(addr, "default")
await client(addr, "default")
self.assertEqual(var.get(), "after_server")
self.run_coro(test())
def test_gh140947(self):
# See https://github.com/python/cpython/issues/140947
cvar1 = contextvars.ContextVar("cvar1")
cvar2 = contextvars.ContextVar("cvar2")
cvar3 = contextvars.ContextVar("cvar3")
results = {}
is_ssl = self.server_ssl_context is not None
def capture_context(meth):
result = []
for k,v in contextvars.copy_context().items():
if k.name.startswith("cvar"):
result.append((k.name, v))
results[meth] = sorted(result)
class DemoProtocol(asyncio.Protocol):
def __init__(self, on_conn_lost):
self.transport = None
self.on_conn_lost = on_conn_lost
self.tasks = set()
def connection_made(self, transport):
capture_context("connection_made")
self.transport = transport
def data_received(self, data):
capture_context("data_received")
task = asyncio.create_task(self.asgi())
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
self.transport.pause_reading()
def connection_lost(self, exc):
capture_context("connection_lost")
if not self.on_conn_lost.done():
self.on_conn_lost.set_result(True)
async def asgi(self):
capture_context("asgi start")
cvar1.set(True)
# make sure that we only resume after the pause
# otherwise the resume does nothing
if is_ssl:
while not self.transport._ssl_protocol._app_reading_paused:
await asyncio.sleep(0.01)
else:
while not self.transport._paused:
await asyncio.sleep(0.01)
cvar2.set(True)
self.transport.resume_reading()
cvar3.set(True)
capture_context("asgi end")
async def main():
loop = asyncio.get_running_loop()
on_conn_lost = loop.create_future()
server = await loop.create_server(
lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0,
ssl=self.server_ssl_context)
async with server:
addr = server.sockets[0].getsockname()
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
writer.write(b"anything")
await writer.drain()
writer.close()
await writer.wait_closed()
await on_conn_lost
self.run_coro(main())
self.assertDictEqual(results, {
"connection_made": [],
"data_received": [],
"asgi start": [],
"asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)],
"connection_lost": [],
})
class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
    """Run the server-context tests on loops from ``asyncio.new_event_loop``."""
    # staticmethod() so ``self.loop_factory`` yields the plain function
    # rather than a bound method when passed to asyncio.run().
    loop_factory = staticmethod(asyncio.new_event_loop)
@unittest.skipUnless(ssl, "SSL not available")
class AsyncioEventLoopSSLTests(AsyncioEventLoopTests):
    """Same tests as ``AsyncioEventLoopTests`` but with TLS on both ends."""

    def setUp(self):
        super().setUp()
        # Build the test-suite TLS contexts for the server and client sides.
        server_ctx = test_utils.simple_server_sslcontext()
        client_ctx = test_utils.simple_client_sslcontext()
        self.server_ssl_context = server_ctx
        self.client_ssl_context = client_ctx
if sys.platform == "win32":
    # Windows offers both a proactor and a selector event loop; run the
    # server-context tests against each, with and without TLS.

    class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase):
        """Run the server-context tests on ``asyncio.ProactorEventLoop``."""
        loop_factory = asyncio.ProactorEventLoop

    class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase):
        """Run the server-context tests on ``asyncio.SelectorEventLoop``."""
        loop_factory = asyncio.SelectorEventLoop

    @unittest.skipUnless(ssl, "SSL not available")
    class AsyncioProactorEventLoopSSLTests(AsyncioProactorEventLoopTests):
        """Proactor-loop tests with TLS on both ends."""

        def setUp(self):
            super().setUp()
            server_ctx = test_utils.simple_server_sslcontext()
            client_ctx = test_utils.simple_client_sslcontext()
            self.server_ssl_context = server_ctx
            self.client_ssl_context = client_ctx

    @unittest.skipUnless(ssl, "SSL not available")
    class AsyncioSelectorEventLoopSSLTests(AsyncioSelectorEventLoopTests):
        """Selector-loop tests with TLS on both ends."""

        def setUp(self):
            super().setUp()
            server_ctx = test_utils.simple_server_sslcontext()
            client_ctx = test_utils.simple_client_sslcontext()
            self.server_ssl_context = server_ctx
            self.client_ssl_context = client_ctx
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()