mirror of
https://github.com/maubot/maubot
synced 2025-09-02 00:00:39 +00:00
Refactor how plugins are started and update spec
This commit is contained in:
parent
b96d6e6a94
commit
9e066478a9
10 changed files with 160 additions and 79 deletions
106
maubot/client.py
106
maubot/client.py
|
@ -18,6 +18,7 @@ from aiohttp import ClientSession
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
EventType, Filter, RoomFilter, RoomEventFilter)
|
||||
|
||||
|
@ -31,6 +32,7 @@ log = logging.getLogger("maubot.client")
|
|||
|
||||
|
||||
class Client:
|
||||
log: logging.Logger
|
||||
loop: asyncio.AbstractEventLoop
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
http_client: ClientSession = None
|
||||
|
@ -38,42 +40,97 @@ class Client:
|
|||
references: Set['PluginInstance']
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
started: bool
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
self.cache[self.id] = self
|
||||
self.log = log.getChild(self.id)
|
||||
self.references = set()
|
||||
self.started = False
|
||||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=self.log, loop=self.loop, store=self.db_instance)
|
||||
if self.autojoin:
|
||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
|
||||
def start(self) -> None:
|
||||
asyncio.ensure_future(self._start(), loop=self.loop)
|
||||
|
||||
async def _start(self) -> None:
|
||||
async def start(self, try_n: Optional[int] = 0) -> None:
|
||||
try:
|
||||
if not self.filter_id:
|
||||
self.filter_id = await self.client.create_filter(Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
),
|
||||
),
|
||||
))
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
await self.client.start(self.filter_id)
|
||||
if try_n > 0:
|
||||
await asyncio.sleep(try_n * 10)
|
||||
await self._start(try_n)
|
||||
except Exception:
|
||||
self.log.exception("starting raised exception")
|
||||
self.log.exception("Failed to start")
|
||||
|
||||
async def _start(self, try_n: Optional[int] = 0) -> None:
|
||||
if not self.enabled:
|
||||
self.log.debug("Not starting disabled client")
|
||||
return
|
||||
elif self.started:
|
||||
self.log.warning("Ignoring start() call to started client")
|
||||
return
|
||||
try:
|
||||
user_id = await self.client.whoami()
|
||||
except MatrixInvalidToken as e:
|
||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||
self.enabled = False
|
||||
return
|
||||
except MatrixRequestError:
|
||||
if try_n >= 5:
|
||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||
self.enabled = False
|
||||
else:
|
||||
self.log.exception(f"Failed to get /account/whoami, "
|
||||
f"retrying in {(try_n + 1) * 10}s")
|
||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||
return
|
||||
if user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
||||
self.enabled = False
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.filter_id = await self.client.create_filter(Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
),
|
||||
),
|
||||
))
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
if self.sync:
|
||||
self.client.start(self.filter_id)
|
||||
self.started = True
|
||||
self.log.info("Client started, starting plugin instances...")
|
||||
await self.start_plugins()
|
||||
|
||||
async def start_plugins(self) -> None:
|
||||
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
|
||||
|
||||
async def stop_plugins(self) -> None:
|
||||
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
|
||||
loop=self.loop)
|
||||
|
||||
def stop(self) -> None:
|
||||
self.started = False
|
||||
self.client.stop()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"homeserver": self.homeserver,
|
||||
"access_token": self.access_token,
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"sync": self.sync,
|
||||
"autojoin": self.autojoin,
|
||||
"displayname": self.displayname,
|
||||
"avatar_url": self.avatar_url,
|
||||
"instances": [instance.to_dict() for instance in self.references],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
|
||||
try:
|
||||
|
@ -111,6 +168,14 @@ class Client:
|
|||
self.client.api.token = value
|
||||
self.db_instance.access_token = value
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool) -> None:
|
||||
self.db_instance.enabled = value
|
||||
|
||||
@property
|
||||
def next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
|
@ -168,8 +233,7 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> None:
|
||||
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
for client in Client.all():
|
||||
client.start()
|
||||
return Client.all()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue