More changes

This commit is contained in:
Tulir Asokan 2018-10-16 16:41:02 +03:00
parent 0b246e44a8
commit eef052b1e9
9 changed files with 195 additions and 61 deletions

View file

@ -13,62 +13,21 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Union, Callable
from typing import Dict, List, Optional
from aiohttp import ClientSession
import asyncio
import logging
from mautrix import Client as MatrixClient
from mautrix.client import EventHandler
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
EventType, MessageEvent)
from mautrix.types import UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership, EventType
from .command_spec import ParsedCommand
from .db import DBClient
from .matrix import MaubotMatrixClient
log = logging.getLogger("maubot.client")
class MaubotMatrixClient(MatrixClient):
def __init__(self, maubot_client: 'Client', *args, **kwargs):
super().__init__(*args, **kwargs)
self._maubot_client = maubot_client
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
class Client:
loop: asyncio.AbstractEventLoop
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
@ -78,26 +37,33 @@ class Client:
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
self.cache[self.id] = self
self.client = MaubotMatrixClient(maubot_client=self,
store=self.db_instance,
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
self.client = MaubotMatrixClient(maubot_client=self, store=self.db_instance,
mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=log.getChild(self.id))
if self.autojoin:
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
def start(self) -> None:
asyncio.ensure_future(self.client.start(), loop=self.loop)
def stop(self) -> None:
self.client.stop()
@classmethod
def get(cls, user_id: UserID) -> Optional['Client']:
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
try:
return cls.cache[user_id]
except KeyError:
db_instance = DBClient.query.get(user_id)
db_instance = db_instance or DBClient.query.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
# region Properties
@property
@ -176,3 +142,9 @@ class Client:
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)
def init(loop: asyncio.AbstractEventLoop) -> None:
Client.loop = loop
for client in Client.all():
client.start()