Refactor things and implement instance API

This commit is contained in:
Tulir Asokan 2018-11-01 18:11:54 +02:00
parent cbeff0c0cb
commit bc87b2a02b
14 changed files with 249 additions and 100 deletions

View file

@ -14,10 +14,12 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Set, TYPE_CHECKING
from aiohttp import ClientSession
import asyncio
import logging
from sqlalchemy.orm import Session
from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
EventType, Filter, RoomFilter, RoomEventFilter)
@ -32,6 +34,7 @@ log = logging.getLogger("maubot.client")
class Client:
db: Session = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
@ -73,12 +76,12 @@ class Client:
user_id = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
self.enabled = False
self.db_instance.enabled = False
return
except MatrixRequestError:
if try_n >= 5:
self.log.exception("Failed to get /account/whoami, disabling client")
self.enabled = False
self.db_instance.enabled = False
else:
self.log.exception(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s")
@ -86,7 +89,7 @@ class Client:
return
if user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
self.enabled = False
self.db_instance.enabled = False
return
if not self.filter_id:
self.filter_id = await self.client.create_filter(Filter(
@ -100,8 +103,7 @@ class Client:
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
if self.sync:
self.client.start(self.filter_id)
self.start_sync()
self.started = True
self.log.info("Client started, starting plugin instances...")
await self.start_plugins()
@ -110,12 +112,19 @@ class Client:
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
async def stop_plugins(self) -> None:
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started],
loop=self.loop)
def start_sync(self) -> None:
if self.sync:
self.client.start(self.filter_id)
def stop_sync(self) -> None:
self.client.stop()
def stop(self) -> None:
self.started = False
self.client.stop()
self.stop_sync()
def to_dict(self) -> dict:
return {
@ -233,7 +242,8 @@ class Client:
# endregion
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.db = db
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()