-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path_ainternal.py
More file actions
66 lines (48 loc) · 1.72 KB
/
Copy path_ainternal.py
File metadata and controls
66 lines (48 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Shared async utility functions for the MySQL checkpoint classes."""
from __future__ import annotations
from collections.abc import AsyncIterator, Mapping, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import (
Any,
Generic,
Protocol,
TypeVar,
Union,
cast,
)
class AsyncDictCursor(AbstractAsyncContextManager, Protocol):
    """Structural interface a cursor must satisfy for the MySQL checkpointers.

    The cursor is an async context manager whose query coroutines produce rows
    as ``dict[str, Any]`` and which can be iterated with ``async for``.
    """

    async def execute(
        self,
        operation: str,
        parameters: Sequence[Any] | Mapping[str, Any] = ...,
        /,
    ) -> object:
        """Execute *operation*, optionally binding *parameters*."""

    async def executemany(
        self, operation: str, seq_of_parameters: Sequence[Sequence[Any]], /
    ) -> object:
        """Execute *operation* once for each entry in *seq_of_parameters*."""

    async def fetchone(self) -> dict[str, Any] | None:
        """Return a single row as a dict, or ``None``."""

    async def fetchall(self) -> Sequence[dict[str, Any]]:
        """Return the rows as a sequence of dicts."""

    def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
        """Return an async iterator over rows (each a dict)."""


# Type variable for a concrete cursor type that satisfies ``AsyncDictCursor``.
R = TypeVar("R", bound=AsyncDictCursor)
class AsyncConnection(AbstractAsyncContextManager, Protocol):
    """Structural interface a MySQL connection must satisfy.

    Besides being usable with ``async with``, the connection exposes
    transaction-control coroutines and a way to set its character set.
    """

    async def begin(self) -> None:
        """Begin a transaction."""

    async def commit(self) -> None:
        """Commit the current transaction."""

    async def rollback(self) -> None:
        """Roll back the current transaction."""

    async def set_charset(self, charset: str) -> None:
        """Set the connection's character set to *charset*."""


# Invariant type variable for a concrete connection type.
C = TypeVar("C", bound=AsyncConnection)
# Covariant counterpart: a protocol type parameter that only appears in
# return position (as in ``AsyncPool.acquire``) must be covariant.
COut = TypeVar("COut", bound=AsyncConnection, covariant=True)
class AsyncPool(Protocol, Generic[COut]):
    """Structural interface for a pool handing out connections of type ``COut``."""

    def acquire(self) -> COut:
        """Obtain a connection from the pool (a plain call, not a coroutine)."""


# A value accepted wherever the checkpointers need a connection: either a
# connection of type ``C`` itself, or a pool that produces ``C`` connections.
Conn = Union[C, AsyncPool[C]]
@asynccontextmanager
async def get_connection(
    conn: Conn[C],
) -> AsyncIterator[C]:
    """Yield a connection, given either a connection or a connection pool.

    If *conn* has a ``cursor`` attribute it is treated as a connection and
    yielded unchanged; this helper does not close it. Otherwise, if it has an
    ``acquire`` attribute, it is treated as a pool: ``pool.acquire()`` is
    entered with ``async with`` and the resulting connection is yielded. That
    inner context is exited when this one exits (presumably returning the
    connection to the pool; this depends on the driver).

    Raises:
        TypeError: on entry, if *conn* has neither ``cursor`` nor ``acquire``.
    """
    # Connection check comes first, so an object having both attributes is
    # used directly rather than as a pool.
    if hasattr(conn, "cursor"):
        yield cast(C, conn)
        return
    if not hasattr(conn, "acquire"):
        raise TypeError(f"Invalid connection type: {type(conn)}")
    pool = cast(AsyncPool[C], conn)
    async with pool.acquire() as acquired:
        yield acquired