Blacken and isort code

This commit is contained in:
Tulir Asokan 2022-03-25 14:22:37 +02:00
parent 6257979e7c
commit 068e268c63
97 changed files with 1781 additions and 1086 deletions

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,32 +13,46 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
import asyncio
import logging
from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
DeviceID,
EventFilter,
EventType,
Filter,
FilterID,
Membership,
PresenceState,
RoomEventFilter,
RoomFilter,
StateEvent,
StateFilter,
StrippedStateEvent,
SyncToken,
UserID,
)
from .lib.store_proxy import SyncStoreProxy
from .db import DBClient
from .lib.store_proxy import SyncStoreProxy
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, PgCryptoStore
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
from mautrix.util.async_db import Database as AsyncDatabase
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
crypto_import_error = None
except ImportError as e:
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
@ -46,8 +60,8 @@ except ImportError as e:
crypto_import_error = e
if TYPE_CHECKING:
from .instance import PluginInstance
from .config import Config
from .instance import PluginInstance
log = logging.getLogger("maubot.client")
@ -55,20 +69,20 @@ log = logging.getLogger("maubot.client")
class Client:
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
cache: dict[UserID, Client] = {}
http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_db: Optional['AsyncDatabase'] = None
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
crypto_db: AsyncDatabase | None = None
references: Set['PluginInstance']
references: set[PluginInstance]
db_instance: DBClient
client: MaubotMatrixClient
crypto: Optional['OlmMachine']
crypto_store: Optional['PgCryptoStore']
crypto: OlmMachine | None
crypto_store: PgCryptoStore | None
started: bool
remote_displayname: Optional[str]
remote_avatar_url: Optional[ContentURI]
remote_displayname: str | None
remote_avatar_url: ContentURI | None
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
@ -79,11 +93,17 @@ class Client:
self.sync_ok = True
self.remote_displayname = None
self.remote_avatar_url = None
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
self.client = MaubotMatrixClient(
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=self.log,
loop=self.loop,
device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store,
)
if self.enable_crypto:
self._prepare_crypto()
else:
@ -104,8 +124,10 @@ class Client:
return False
elif not OlmMachine:
global crypto_import_error
self.log.warning("Client has device ID, but encryption dependencies not installed",
exc_info=crypto_import_error)
self.log.warning(
"Client has device ID, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
# Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None
return False
@ -115,8 +137,9 @@ class Client:
return True
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto",
db=self.crypto_db)
self.crypto_store = PgCryptoStore(
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
)
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
@ -133,13 +156,13 @@ class Client:
for event_type, func in handlers:
self.client.remove_event_handler(event_type, func)
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None:
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
async def start(self, try_n: Optional[int] = 0) -> None:
async def start(self, try_n: int | None = 0) -> None:
try:
if try_n > 0:
await asyncio.sleep(try_n * 10)
@ -152,15 +175,16 @@ class Client:
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database, "
"resetting encryption")
self.log.warning(
"Mismatching device ID in crypto store and main database, " "resetting encryption"
)
await self.crypto_store.delete()
crypto_device_id = None
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
async def _start(self, try_n: Optional[int] = 0) -> None:
async def _start(self, try_n: int | None = 0) -> None:
if not self.enabled:
self.log.debug("Not starting disabled client")
return
@ -179,8 +203,9 @@ class Client:
self.log.exception("Failed to get /account/whoami, disabling client")
self.db_instance.enabled = False
else:
self.log.warning(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s: {e}")
self.log.warning(
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
)
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return
if whoami.user_id != self.id:
@ -188,25 +213,30 @@ class Client:
self.db_instance.enabled = False
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(f"Device ID mismatch: expected {self.device_id}, "
f"but got {whoami.device_id}")
self.log.error(
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
)
self.db_instance.enabled = False
return
if not self.filter_id:
self.db_instance.edit(filter_id=await self.client.create_filter(Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
lazy_load_members=True,
),
state=StateFilter(
lazy_load_members=True,
self.db_instance.edit(
filter_id=await self.client.create_filter(
Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
lazy_load_members=True,
),
state=StateFilter(
lazy_load_members=True,
),
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
),
)
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
),
)))
)
)
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
@ -258,8 +288,9 @@ class Client:
"homeserver": self.homeserver,
"access_token": self.access_token,
"device_id": self.device_id,
"fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account
else None),
"fingerprint": (
self.crypto.account.fingerprint if self.crypto and self.crypto.account else None
),
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
@ -274,7 +305,7 @@ class Client:
}
@classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
@ -284,7 +315,7 @@ class Client:
return Client(db_instance)
@classmethod
def all(cls) -> Iterable['Client']:
def all(cls) -> Iterable[Client]:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_tombstone(self, evt: StateEvent) -> None:
@ -324,8 +355,12 @@ class Client:
else:
await self._update_remote_profile()
async def update_access_details(self, access_token: Optional[str], homeserver: Optional[str],
device_id: Optional[str] = None) -> None:
async def update_access_details(
self,
access_token: str | None,
homeserver: str | None,
device_id: str | None = None,
) -> None:
if not access_token and not homeserver:
return
if device_id is None:
@ -338,10 +373,16 @@ class Client:
and device_id == self.device_id
):
return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
device_id=device_id, client_session=self.http_client,
log=self.log, state_store=self.global_state_store)
new_client = MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=access_token or self.access_token,
loop=self.loop,
device_id=device_id,
client_session=self.http_client,
log=self.log,
state_store=self.global_state_store,
)
whoami = await new_client.whoami()
if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}")
@ -455,7 +496,7 @@ class Client:
# endregion
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop