Switch to asyncpg/aiosqlite

Fixes #142
Fixes #98
Probably fixes #62
This commit is contained in:
Tulir Asokan 2022-03-25 19:45:48 +02:00
parent 068e268c63
commit 21ed971d2f
43 changed files with 911 additions and 955 deletions

View file

@ -15,14 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
from collections import defaultdict
import asyncio
import logging
from aiohttp import ClientSession
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
@ -41,69 +41,110 @@ from mautrix.types import (
SyncToken,
UserID,
)
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.logging import TraceLogger
from .db import DBClient
from .lib.store_proxy import SyncStoreProxy
from .db import Client as DBClient
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
from mautrix.util.async_db import Database as AsyncDatabase
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
from mautrix.crypto import OlmMachine, PgCryptoStore
crypto_import_error = None
except ImportError as e:
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
SQLStateStore = BaseSQLStateStore
OlmMachine = PgCryptoStore = None
crypto_import_error = e
if TYPE_CHECKING:
from .config import Config
from .__main__ import Maubot
from .instance import PluginInstance
log = logging.getLogger("maubot.client")
class Client:
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
class Client(DBClient):
maubot: "Maubot" = None
cache: dict[UserID, Client] = {}
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: TraceLogger = logging.getLogger("maubot.client")
http_client: ClientSession = None
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
crypto_db: AsyncDatabase | None = None
references: set[PluginInstance]
db_instance: DBClient
client: MaubotMatrixClient
crypto: OlmMachine | None
crypto_store: PgCryptoStore | None
started: bool
sync_ok: bool
remote_displayname: str | None
remote_avatar_url: ContentURI | None
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
def __init__(
self,
id: UserID,
homeserver: str,
access_token: str,
device_id: DeviceID,
enabled: bool = False,
next_batch: SyncToken = "",
filter_id: FilterID = "",
sync: bool = True,
autojoin: bool = True,
online: bool = True,
displayname: str = "disable",
avatar_url: str = "disable",
) -> None:
super().__init__(
id=id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id,
enabled=bool(enabled),
next_batch=next_batch,
filter_id=filter_id,
sync=bool(sync),
autojoin=bool(autojoin),
online=bool(online),
displayname=displayname,
avatar_url=avatar_url,
)
self._postinited = False
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def _make_client(
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
) -> MaubotMatrixClient:
return MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=token or self.access_token,
client_session=self.http_client,
log=self.log,
crypto_log=self.log.getChild("crypto"),
loop=self.maubot.loop,
device_id=device_id or self.device_id,
sync_store=self,
state_store=self.maubot.state_store,
)
def postinit(self) -> None:
if self._postinited:
raise RuntimeError("postinit() called twice")
self._postinited = True
self.cache[self.id] = self
self.log = log.getChild(self.id)
self.log = self.log.getChild(self.id)
self.http_client = ClientSession(loop=self.maubot.loop)
self.references = set()
self.started = False
self.sync_ok = True
self.remote_displayname = None
self.remote_avatar_url = None
self.client = MaubotMatrixClient(
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=self.log,
loop=self.loop,
device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store,
)
self.client = self._make_client()
if self.enable_crypto:
self._prepare_crypto()
else:
@ -118,6 +159,12 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
@property
def enable_crypto(self) -> bool:
if not self.device_id:
@ -131,16 +178,21 @@ class Client:
# Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None
return False
elif not self.crypto_db:
elif not self.maubot.crypto_db:
self.log.warning("Client has device ID, but crypto database is not prepared")
return False
return True
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
)
self.crypto = OlmMachine(
self.client,
self.crypto_store,
self.maubot.state_store,
log=self.client.crypto_log,
)
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
def _remove_crypto_event_handlers(self) -> None:
@ -156,12 +208,6 @@ class Client:
for event_type, func in handlers:
self.client.remove_event_handler(event_type, func)
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
async def start(self, try_n: int | None = 0) -> None:
try:
if try_n > 0:
@ -196,47 +242,50 @@ class Client:
whoami = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.db_instance.enabled = False
self.enabled = False
await self.update()
return
except Exception as e:
if try_n >= 8:
self.log.exception("Failed to get /account/whoami, disabling client")
self.db_instance.enabled = False
self.enabled = False
await self.update()
else:
self.log.warning(
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
)
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
_ = asyncio.create_task(self.start(try_n + 1))
return
if whoami.user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
self.db_instance.enabled = False
self.enabled = False
await self.update()
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
)
self.db_instance.enabled = False
self.enabled = False
await self.update()
return
if not self.filter_id:
self.db_instance.edit(
filter_id=await self.client.create_filter(
Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
lazy_load_members=True,
),
state=StateFilter(
lazy_load_members=True,
),
self.filter_id = await self.client.create_filter(
Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
lazy_load_members=True,
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
state=StateFilter(
lazy_load_members=True,
),
)
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
),
)
)
await self.update()
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
@ -270,18 +319,13 @@ class Client:
if self.crypto:
await self.crypto_store.close()
def clear_cache(self) -> None:
async def clear_cache(self) -> None:
self.stop_sync()
self.db_instance.edit(filter_id="", next_batch="")
self.filter_id = FilterID("")
self.next_batch = SyncToken("")
await self.update()
self.start_sync()
def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
self.db_instance.delete()
def to_dict(self) -> dict:
return {
"id": self.id,
@ -304,20 +348,6 @@ class Client:
"instances": [instance.to_dict() for instance in self.references],
}
@classmethod
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
db_instance = db_instance or DBClient.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> Iterable[Client]:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_tombstone(self, evt: StateEvent) -> None:
if not evt.content.replacement_room:
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
@ -329,7 +359,7 @@ class Client:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id)
async def update_started(self, started: bool) -> None:
async def update_started(self, started: bool | None) -> None:
if started is None or started == self.started:
return
if started:
@ -337,23 +367,65 @@ class Client:
else:
await self.stop()
async def update_displayname(self, displayname: str) -> None:
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
if enabled is None or enabled == self.enabled:
return
self.enabled = enabled
if save:
await self.update()
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
if displayname is None or displayname == self.displayname:
return
self.db_instance.displayname = displayname
self.displayname = displayname
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
else:
await self._update_remote_profile()
if save:
await self.update()
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
if avatar_url is None or avatar_url == self.avatar_url:
return
self.db_instance.avatar_url = avatar_url
self.avatar_url = avatar_url
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
else:
await self._update_remote_profile()
if save:
await self.update()
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
if sync is None or self.sync == sync:
return
self.sync = sync
if self.started:
if sync:
self.start_sync()
else:
self.stop_sync()
if save:
await self.update()
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
if autojoin is None or autojoin == self.autojoin:
return
if autojoin:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.autojoin = autojoin
if save:
await self.update()
async def update_online(self, online: bool | None, save: bool = True) -> None:
if online is None or online == self.online:
return
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
self.online = online
if save:
await self.update()
async def update_access_details(
self,
@ -373,22 +445,13 @@ class Client:
and device_id == self.device_id
):
return
new_client = MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=access_token or self.access_token,
loop=self.loop,
device_id=device_id,
client_session=self.http_client,
log=self.log,
state_store=self.global_state_store,
)
new_client = self._make_client(homeserver, access_token, device_id)
whoami = await new_client.whoami()
if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}")
elif whoami.device_id and device_id and whoami.device_id != device_id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
new_client.sync_store = SyncStoreProxy(self.db_instance)
new_client.sync_store = self
self.stop_sync()
# TODO this event handler transfer is pretty hacky
@ -398,9 +461,9 @@ class Client:
new_client.global_event_handlers = self.client.global_event_handlers
self.client = new_client
self.db_instance.homeserver = homeserver
self.db_instance.access_token = access_token
self.db_instance.device_id = device_id
self.homeserver = homeserver
self.access_token = access_token
self.device_id = device_id
if self.enable_crypto:
self._prepare_crypto()
await self._start_crypto()
@ -413,97 +476,53 @@ class Client:
profile = await self.client.get_profile(self.id)
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
# region Properties
async def delete(self) -> None:
try:
del self.cache[self.id]
except KeyError:
pass
await super().delete()
@property
def id(self) -> UserID:
return self.db_instance.id
@classmethod
@async_getter_lock
async def get(
cls,
user_id: UserID,
*,
homeserver: str | None = None,
access_token: str | None = None,
device_id: DeviceID | None = None,
) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
pass
@property
def homeserver(self) -> str:
return self.db_instance.homeserver
user = cast(cls, await super().get(user_id))
if user is not None:
user.postinit()
return user
@property
def access_token(self) -> str:
return self.db_instance.access_token
if homeserver and access_token:
user = cls(
user_id,
homeserver=homeserver,
access_token=access_token,
device_id=device_id or "",
)
await user.insert()
user.postinit()
return user
@property
def device_id(self) -> DeviceID:
return self.db_instance.device_id
return None
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@enabled.setter
def enabled(self, value: bool) -> None:
self.db_instance.enabled = value
@property
def next_batch(self) -> SyncToken:
return self.db_instance.next_batch
@property
def filter_id(self) -> FilterID:
return self.db_instance.filter_id
@property
def sync(self) -> bool:
return self.db_instance.sync
@sync.setter
def sync(self, value: bool) -> None:
if value == self.db_instance.sync:
return
self.db_instance.sync = value
if self.started:
if value:
self.start_sync()
else:
self.stop_sync()
@property
def autojoin(self) -> bool:
return self.db_instance.autojoin
@autojoin.setter
def autojoin(self, value: bool) -> None:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
else:
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
self.db_instance.autojoin = value
@property
def online(self) -> bool:
return self.db_instance.online
@online.setter
def online(self, value: bool) -> None:
self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
self.db_instance.online = value
@property
def displayname(self) -> str:
return self.db_instance.displayname
@property
def avatar_url(self) -> ContentURI:
return self.db_instance.avatar_url
# endregion
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
if OlmMachine:
db_url = config["crypto_database"]
if db_url == "default":
db_url = config["database"]
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
return Client.all()
@classmethod
async def all(cls) -> AsyncGenerator[Client, None]:
users = await super().all()
user: cls
for user in users:
try:
yield cls.cache[user.id]
except KeyError:
user.postinit()
yield user