mirror of
https://github.com/maubot/maubot
synced 2025-09-02 00:00:39 +00:00
parent
068e268c63
commit
21ed971d2f
43 changed files with 911 additions and 955 deletions
407
maubot/client.py
407
maubot/client.py
|
@ -15,14 +15,14 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from mautrix.client import InternalEventType
|
||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||
from mautrix.errors import MatrixInvalidToken
|
||||
from mautrix.types import (
|
||||
ContentURI,
|
||||
|
@ -41,69 +41,110 @@ from mautrix.types import (
|
|||
SyncToken,
|
||||
UserID,
|
||||
)
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
from .db import DBClient
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .db import Client as DBClient
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
|
||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||
pass
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStore
|
||||
|
||||
crypto_import_error = None
|
||||
except ImportError as e:
|
||||
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
|
||||
SQLStateStore = BaseSQLStateStore
|
||||
OlmMachine = PgCryptoStore = None
|
||||
crypto_import_error = e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
from .__main__ import Maubot
|
||||
from .instance import PluginInstance
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
||||
class Client:
|
||||
log: logging.Logger = None
|
||||
loop: asyncio.AbstractEventLoop = None
|
||||
class Client(DBClient):
|
||||
maubot: "Maubot" = None
|
||||
cache: dict[UserID, Client] = {}
|
||||
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
log: TraceLogger = logging.getLogger("maubot.client")
|
||||
|
||||
http_client: ClientSession = None
|
||||
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
|
||||
crypto_db: AsyncDatabase | None = None
|
||||
|
||||
references: set[PluginInstance]
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
crypto: OlmMachine | None
|
||||
crypto_store: PgCryptoStore | None
|
||||
started: bool
|
||||
sync_ok: bool
|
||||
|
||||
remote_displayname: str | None
|
||||
remote_avatar_url: ContentURI | None
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
def __init__(
|
||||
self,
|
||||
id: UserID,
|
||||
homeserver: str,
|
||||
access_token: str,
|
||||
device_id: DeviceID,
|
||||
enabled: bool = False,
|
||||
next_batch: SyncToken = "",
|
||||
filter_id: FilterID = "",
|
||||
sync: bool = True,
|
||||
autojoin: bool = True,
|
||||
online: bool = True,
|
||||
displayname: str = "disable",
|
||||
avatar_url: str = "disable",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
device_id=device_id,
|
||||
enabled=bool(enabled),
|
||||
next_batch=next_batch,
|
||||
filter_id=filter_id,
|
||||
sync=bool(sync),
|
||||
autojoin=bool(autojoin),
|
||||
online=bool(online),
|
||||
displayname=displayname,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
self._postinited = False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
@classmethod
|
||||
def init_cls(cls, maubot: "Maubot") -> None:
|
||||
cls.maubot = maubot
|
||||
|
||||
def _make_client(
|
||||
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
|
||||
) -> MaubotMatrixClient:
|
||||
return MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=homeserver or self.homeserver,
|
||||
token=token or self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
crypto_log=self.log.getChild("crypto"),
|
||||
loop=self.maubot.loop,
|
||||
device_id=device_id or self.device_id,
|
||||
sync_store=self,
|
||||
state_store=self.maubot.state_store,
|
||||
)
|
||||
|
||||
def postinit(self) -> None:
|
||||
if self._postinited:
|
||||
raise RuntimeError("postinit() called twice")
|
||||
self._postinited = True
|
||||
self.cache[self.id] = self
|
||||
self.log = log.getChild(self.id)
|
||||
self.log = self.log.getChild(self.id)
|
||||
self.http_client = ClientSession(loop=self.maubot.loop)
|
||||
self.references = set()
|
||||
self.started = False
|
||||
self.sync_ok = True
|
||||
self.remote_displayname = None
|
||||
self.remote_avatar_url = None
|
||||
self.client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=self.homeserver,
|
||||
token=self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
loop=self.loop,
|
||||
device_id=self.device_id,
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
self.client = self._make_client()
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
else:
|
||||
|
@ -118,6 +159,12 @@ class Client:
|
|||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
||||
return handler
|
||||
|
||||
@property
|
||||
def enable_crypto(self) -> bool:
|
||||
if not self.device_id:
|
||||
|
@ -131,16 +178,21 @@ class Client:
|
|||
# Clear the stack trace after it's logged once to avoid spamming logs
|
||||
crypto_import_error = None
|
||||
return False
|
||||
elif not self.crypto_db:
|
||||
elif not self.maubot.crypto_db:
|
||||
self.log.warning("Client has device ID, but crypto database is not prepared")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _prepare_crypto(self) -> None:
|
||||
self.crypto_store = PgCryptoStore(
|
||||
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
|
||||
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
|
||||
)
|
||||
self.crypto = OlmMachine(
|
||||
self.client,
|
||||
self.crypto_store,
|
||||
self.maubot.state_store,
|
||||
log=self.client.crypto_log,
|
||||
)
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
|
||||
def _remove_crypto_event_handlers(self) -> None:
|
||||
|
@ -156,12 +208,6 @@ class Client:
|
|||
for event_type, func in handlers:
|
||||
self.client.remove_event_handler(event_type, func)
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
||||
return handler
|
||||
|
||||
async def start(self, try_n: int | None = 0) -> None:
|
||||
try:
|
||||
if try_n > 0:
|
||||
|
@ -196,47 +242,50 @@ class Client:
|
|||
whoami = await self.client.whoami()
|
||||
except MatrixInvalidToken as e:
|
||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
except Exception as e:
|
||||
if try_n >= 8:
|
||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
else:
|
||||
self.log.warning(
|
||||
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
|
||||
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
|
||||
)
|
||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||
_ = asyncio.create_task(self.start(try_n + 1))
|
||||
return
|
||||
if whoami.user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
||||
self.log.error(
|
||||
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
|
||||
)
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.db_instance.edit(
|
||||
filter_id=await self.client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
),
|
||||
self.filter_id = await self.client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
),
|
||||
)
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
),
|
||||
)
|
||||
)
|
||||
await self.update()
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
|
@ -270,18 +319,13 @@ class Client:
|
|||
if self.crypto:
|
||||
await self.crypto_store.close()
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
async def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
self.db_instance.edit(filter_id="", next_batch="")
|
||||
self.filter_id = FilterID("")
|
||||
self.next_batch = SyncToken("")
|
||||
await self.update()
|
||||
self.start_sync()
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db_instance.delete()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
|
@ -304,20 +348,6 @@ class Client:
|
|||
"instances": [instance.to_dict() for instance in self.references],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBClient.get(user_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return Client(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable[Client]:
|
||||
return (cls.get(user.id, user) for user in DBClient.all())
|
||||
|
||||
async def _handle_tombstone(self, evt: StateEvent) -> None:
|
||||
if not evt.content.replacement_room:
|
||||
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
|
||||
|
@ -329,7 +359,7 @@ class Client:
|
|||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room(evt.room_id)
|
||||
|
||||
async def update_started(self, started: bool) -> None:
|
||||
async def update_started(self, started: bool | None) -> None:
|
||||
if started is None or started == self.started:
|
||||
return
|
||||
if started:
|
||||
|
@ -337,23 +367,65 @@ class Client:
|
|||
else:
|
||||
await self.stop()
|
||||
|
||||
async def update_displayname(self, displayname: str) -> None:
|
||||
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
|
||||
if enabled is None or enabled == self.enabled:
|
||||
return
|
||||
self.enabled = enabled
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
|
||||
if displayname is None or displayname == self.displayname:
|
||||
return
|
||||
self.db_instance.displayname = displayname
|
||||
self.displayname = displayname
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
else:
|
||||
await self._update_remote_profile()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
|
||||
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
|
||||
if avatar_url is None or avatar_url == self.avatar_url:
|
||||
return
|
||||
self.db_instance.avatar_url = avatar_url
|
||||
self.avatar_url = avatar_url
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
else:
|
||||
await self._update_remote_profile()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
|
||||
if sync is None or self.sync == sync:
|
||||
return
|
||||
self.sync = sync
|
||||
if self.started:
|
||||
if sync:
|
||||
self.start_sync()
|
||||
else:
|
||||
self.stop_sync()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
|
||||
if autojoin is None or autojoin == self.autojoin:
|
||||
return
|
||||
if autojoin:
|
||||
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
else:
|
||||
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
self.autojoin = autojoin
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_online(self, online: bool | None, save: bool = True) -> None:
|
||||
if online is None or online == self.online:
|
||||
return
|
||||
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
|
||||
self.online = online
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_access_details(
|
||||
self,
|
||||
|
@ -373,22 +445,13 @@ class Client:
|
|||
and device_id == self.device_id
|
||||
):
|
||||
return
|
||||
new_client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token,
|
||||
loop=self.loop,
|
||||
device_id=device_id,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
new_client = self._make_client(homeserver, access_token, device_id)
|
||||
whoami = await new_client.whoami()
|
||||
if whoami.user_id != self.id:
|
||||
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
|
||||
new_client.sync_store = SyncStoreProxy(self.db_instance)
|
||||
new_client.sync_store = self
|
||||
self.stop_sync()
|
||||
|
||||
# TODO this event handler transfer is pretty hacky
|
||||
|
@ -398,9 +461,9 @@ class Client:
|
|||
new_client.global_event_handlers = self.client.global_event_handlers
|
||||
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
self.db_instance.access_token = access_token
|
||||
self.db_instance.device_id = device_id
|
||||
self.homeserver = homeserver
|
||||
self.access_token = access_token
|
||||
self.device_id = device_id
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
await self._start_crypto()
|
||||
|
@ -413,97 +476,53 @@ class Client:
|
|||
profile = await self.client.get_profile(self.id)
|
||||
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
|
||||
|
||||
# region Properties
|
||||
async def delete(self) -> None:
|
||||
try:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
await super().delete()
|
||||
|
||||
@property
|
||||
def id(self) -> UserID:
|
||||
return self.db_instance.id
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get(
|
||||
cls,
|
||||
user_id: UserID,
|
||||
*,
|
||||
homeserver: str | None = None,
|
||||
access_token: str | None = None,
|
||||
device_id: DeviceID | None = None,
|
||||
) -> Client | None:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def homeserver(self) -> str:
|
||||
return self.db_instance.homeserver
|
||||
user = cast(cls, await super().get(user_id))
|
||||
if user is not None:
|
||||
user.postinit()
|
||||
return user
|
||||
|
||||
@property
|
||||
def access_token(self) -> str:
|
||||
return self.db_instance.access_token
|
||||
if homeserver and access_token:
|
||||
user = cls(
|
||||
user_id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
device_id=device_id or "",
|
||||
)
|
||||
await user.insert()
|
||||
user.postinit()
|
||||
return user
|
||||
|
||||
@property
|
||||
def device_id(self) -> DeviceID:
|
||||
return self.db_instance.device_id
|
||||
return None
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool) -> None:
|
||||
self.db_instance.enabled = value
|
||||
|
||||
@property
|
||||
def next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
|
||||
@property
|
||||
def filter_id(self) -> FilterID:
|
||||
return self.db_instance.filter_id
|
||||
|
||||
@property
|
||||
def sync(self) -> bool:
|
||||
return self.db_instance.sync
|
||||
|
||||
@sync.setter
|
||||
def sync(self, value: bool) -> None:
|
||||
if value == self.db_instance.sync:
|
||||
return
|
||||
self.db_instance.sync = value
|
||||
if self.started:
|
||||
if value:
|
||||
self.start_sync()
|
||||
else:
|
||||
self.stop_sync()
|
||||
|
||||
@property
|
||||
def autojoin(self) -> bool:
|
||||
return self.db_instance.autojoin
|
||||
|
||||
@autojoin.setter
|
||||
def autojoin(self, value: bool) -> None:
|
||||
if value == self.db_instance.autojoin:
|
||||
return
|
||||
if value:
|
||||
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
else:
|
||||
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
self.db_instance.autojoin = value
|
||||
|
||||
@property
|
||||
def online(self) -> bool:
|
||||
return self.db_instance.online
|
||||
|
||||
@online.setter
|
||||
def online(self, value: bool) -> None:
|
||||
self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
|
||||
self.db_instance.online = value
|
||||
|
||||
@property
|
||||
def displayname(self) -> str:
|
||||
return self.db_instance.displayname
|
||||
|
||||
@property
|
||||
def avatar_url(self) -> ContentURI:
|
||||
return self.db_instance.avatar_url
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
|
||||
if OlmMachine:
|
||||
db_url = config["crypto_database"]
|
||||
if db_url == "default":
|
||||
db_url = config["database"]
|
||||
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
|
||||
|
||||
return Client.all()
|
||||
@classmethod
|
||||
async def all(cls) -> AsyncGenerator[Client, None]:
|
||||
users = await super().all()
|
||||
user: cls
|
||||
for user in users:
|
||||
try:
|
||||
yield cls.cache[user.id]
|
||||
except KeyError:
|
||||
user.postinit()
|
||||
yield user
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue