Refactor how plugins are started and update spec

This commit is contained in:
Tulir Asokan 2018-11-01 01:51:54 +02:00
parent b96d6e6a94
commit 9e066478a9
10 changed files with 160 additions and 79 deletions

View file

@ -18,6 +18,7 @@ from aiohttp import ClientSession
import asyncio
import logging
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
EventType, Filter, RoomFilter, RoomEventFilter)
@ -31,6 +32,7 @@ log = logging.getLogger("maubot.client")
class Client:
log: logging.Logger
loop: asyncio.AbstractEventLoop
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
@ -38,42 +40,97 @@ class Client:
references: Set['PluginInstance']
db_instance: DBClient
client: MaubotMatrixClient
started: bool
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
self.log = log.getChild(self.id)
self.references = set()
self.started = False
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, store=self.db_instance)
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None:
asyncio.ensure_future(self._start(), loop=self.loop)
async def _start(self) -> None:
async def start(self, try_n: Optional[int] = 0) -> None:
try:
if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
),
),
))
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
await self.client.start(self.filter_id)
if try_n > 0:
await asyncio.sleep(try_n * 10)
await self._start(try_n)
except Exception:
self.log.exception("starting raised exception")
self.log.exception("Failed to start")
async def _start(self, try_n: Optional[int] = 0) -> None:
if not self.enabled:
self.log.debug("Not starting disabled client")
return
elif self.started:
self.log.warning("Ignoring start() call to started client")
return
try:
user_id = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.enabled = False
return
except MatrixRequestError:
if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client")
self.enabled = False
else:
self.log.exception(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s")
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
self.enabled = False
return
if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
),
),
))
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
if self.sync:
self.client.start(self.filter_id)
self.started = True
self.log.info("Client started, starting plugin instances...")
await self.start_plugins()
async def start_plugins(self) -> None:
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None:
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
loop=self.loop)
def stop(self) -> None:
self.started = False
self.client.stop()
def to_dict(self) -> dict:
return {
"id": self.id,
"homeserver": self.homeserver,
"access_token": self.access_token,
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
"autojoin": self.autojoin,
"displayname": self.displayname,
"avatar_url": self.avatar_url,
"instances": [instance.to_dict() for instance in self.references],
}
@classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try:
@ -111,6 +168,14 @@ class Client:
self.client.api.token = value
self.db_instance.access_token = value
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@enabled.setter
def enabled(self, value: bool) -> None:
self.db_instance.enabled = value
@property
def next_batch(self) -> SyncToken:
return self.db_instance.next_batch
@ -168,8 +233,7 @@ class Client:
# endregion
def init(loop: asyncio.AbstractEventLoop) -> None:
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
for client in Client.all():
client.start()
return Client.all()