mirror of
https://github.com/maubot/maubot
synced 2025-09-02 00:00:39 +00:00
Blacken and isort code
This commit is contained in:
parent
6257979e7c
commit
068e268c63
97 changed files with 1781 additions and 1086 deletions
165
maubot/client.py
165
maubot/client.py
|
@ -1,5 +1,5 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2021 Tulir Asokan
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
|
@ -13,32 +13,46 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
||||
PresenceState, StateFilter, DeviceID)
|
||||
from mautrix.client import InternalEventType
|
||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||
from mautrix.errors import MatrixInvalidToken
|
||||
from mautrix.types import (
|
||||
ContentURI,
|
||||
DeviceID,
|
||||
EventFilter,
|
||||
EventType,
|
||||
Filter,
|
||||
FilterID,
|
||||
Membership,
|
||||
PresenceState,
|
||||
RoomEventFilter,
|
||||
RoomFilter,
|
||||
StateEvent,
|
||||
StateFilter,
|
||||
StrippedStateEvent,
|
||||
SyncToken,
|
||||
UserID,
|
||||
)
|
||||
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .db import DBClient
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, PgCryptoStore
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
|
||||
|
||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||
pass
|
||||
|
||||
|
||||
crypto_import_error = None
|
||||
except ImportError as e:
|
||||
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
|
||||
|
@ -46,8 +60,8 @@ except ImportError as e:
|
|||
crypto_import_error = e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .instance import PluginInstance
|
||||
from .config import Config
|
||||
from .instance import PluginInstance
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
@ -55,20 +69,20 @@ log = logging.getLogger("maubot.client")
|
|||
class Client:
|
||||
log: logging.Logger = None
|
||||
loop: asyncio.AbstractEventLoop = None
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
cache: dict[UserID, Client] = {}
|
||||
http_client: ClientSession = None
|
||||
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
|
||||
crypto_db: Optional['AsyncDatabase'] = None
|
||||
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
|
||||
crypto_db: AsyncDatabase | None = None
|
||||
|
||||
references: Set['PluginInstance']
|
||||
references: set[PluginInstance]
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
crypto: Optional['OlmMachine']
|
||||
crypto_store: Optional['PgCryptoStore']
|
||||
crypto: OlmMachine | None
|
||||
crypto_store: PgCryptoStore | None
|
||||
started: bool
|
||||
|
||||
remote_displayname: Optional[str]
|
||||
remote_avatar_url: Optional[ContentURI]
|
||||
remote_displayname: str | None
|
||||
remote_avatar_url: ContentURI | None
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
|
@ -79,11 +93,17 @@ class Client:
|
|||
self.sync_ok = True
|
||||
self.remote_displayname = None
|
||||
self.remote_avatar_url = None
|
||||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store)
|
||||
self.client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=self.homeserver,
|
||||
token=self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
loop=self.loop,
|
||||
device_id=self.device_id,
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
else:
|
||||
|
@ -104,8 +124,10 @@ class Client:
|
|||
return False
|
||||
elif not OlmMachine:
|
||||
global crypto_import_error
|
||||
self.log.warning("Client has device ID, but encryption dependencies not installed",
|
||||
exc_info=crypto_import_error)
|
||||
self.log.warning(
|
||||
"Client has device ID, but encryption dependencies not installed",
|
||||
exc_info=crypto_import_error,
|
||||
)
|
||||
# Clear the stack trace after it's logged once to avoid spamming logs
|
||||
crypto_import_error = None
|
||||
return False
|
||||
|
@ -115,8 +137,9 @@ class Client:
|
|||
return True
|
||||
|
||||
def _prepare_crypto(self) -> None:
|
||||
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto",
|
||||
db=self.crypto_db)
|
||||
self.crypto_store = PgCryptoStore(
|
||||
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
|
||||
)
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
|
||||
|
@ -133,13 +156,13 @@ class Client:
|
|||
for event_type, func in handlers:
|
||||
self.client.remove_event_handler(event_type, func)
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: Dict[str, Any]) -> None:
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
||||
return handler
|
||||
|
||||
async def start(self, try_n: Optional[int] = 0) -> None:
|
||||
async def start(self, try_n: int | None = 0) -> None:
|
||||
try:
|
||||
if try_n > 0:
|
||||
await asyncio.sleep(try_n * 10)
|
||||
|
@ -152,15 +175,16 @@ class Client:
|
|||
await self.crypto_store.open()
|
||||
crypto_device_id = await self.crypto_store.get_device_id()
|
||||
if crypto_device_id and crypto_device_id != self.device_id:
|
||||
self.log.warning("Mismatching device ID in crypto store and main database, "
|
||||
"resetting encryption")
|
||||
self.log.warning(
|
||||
"Mismatching device ID in crypto store and main database, " "resetting encryption"
|
||||
)
|
||||
await self.crypto_store.delete()
|
||||
crypto_device_id = None
|
||||
await self.crypto.load()
|
||||
if not crypto_device_id:
|
||||
await self.crypto_store.put_device_id(self.device_id)
|
||||
|
||||
async def _start(self, try_n: Optional[int] = 0) -> None:
|
||||
async def _start(self, try_n: int | None = 0) -> None:
|
||||
if not self.enabled:
|
||||
self.log.debug("Not starting disabled client")
|
||||
return
|
||||
|
@ -179,8 +203,9 @@ class Client:
|
|||
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||
self.db_instance.enabled = False
|
||||
else:
|
||||
self.log.warning(f"Failed to get /account/whoami, "
|
||||
f"retrying in {(try_n + 1) * 10}s: {e}")
|
||||
self.log.warning(
|
||||
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
|
||||
)
|
||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||
return
|
||||
if whoami.user_id != self.id:
|
||||
|
@ -188,25 +213,30 @@ class Client:
|
|||
self.db_instance.enabled = False
|
||||
return
|
||||
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
||||
self.log.error(f"Device ID mismatch: expected {self.device_id}, "
|
||||
f"but got {whoami.device_id}")
|
||||
self.log.error(
|
||||
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
|
||||
)
|
||||
self.db_instance.enabled = False
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.db_instance.edit(filter_id=await self.client.create_filter(Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
self.db_instance.edit(
|
||||
filter_id=await self.client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
),
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
),
|
||||
)
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
),
|
||||
)))
|
||||
)
|
||||
)
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
|
@ -258,8 +288,9 @@ class Client:
|
|||
"homeserver": self.homeserver,
|
||||
"access_token": self.access_token,
|
||||
"device_id": self.device_id,
|
||||
"fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account
|
||||
else None),
|
||||
"fingerprint": (
|
||||
self.crypto.account.fingerprint if self.crypto and self.crypto.account else None
|
||||
),
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"sync": self.sync,
|
||||
|
@ -274,7 +305,7 @@ class Client:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
|
||||
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
|
@ -284,7 +315,7 @@ class Client:
|
|||
return Client(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['Client']:
|
||||
def all(cls) -> Iterable[Client]:
|
||||
return (cls.get(user.id, user) for user in DBClient.all())
|
||||
|
||||
async def _handle_tombstone(self, evt: StateEvent) -> None:
|
||||
|
@ -324,8 +355,12 @@ class Client:
|
|||
else:
|
||||
await self._update_remote_profile()
|
||||
|
||||
async def update_access_details(self, access_token: Optional[str], homeserver: Optional[str],
|
||||
device_id: Optional[str] = None) -> None:
|
||||
async def update_access_details(
|
||||
self,
|
||||
access_token: str | None,
|
||||
homeserver: str | None,
|
||||
device_id: str | None = None,
|
||||
) -> None:
|
||||
if not access_token and not homeserver:
|
||||
return
|
||||
if device_id is None:
|
||||
|
@ -338,10 +373,16 @@ class Client:
|
|||
and device_id == self.device_id
|
||||
):
|
||||
return
|
||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token, loop=self.loop,
|
||||
device_id=device_id, client_session=self.http_client,
|
||||
log=self.log, state_store=self.global_state_store)
|
||||
new_client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token,
|
||||
loop=self.loop,
|
||||
device_id=device_id,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
whoami = await new_client.whoami()
|
||||
if whoami.user_id != self.id:
|
||||
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
||||
|
@ -455,7 +496,7 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue