Add command matching stuff

This commit is contained in:
Tulir Asokan 2018-10-16 00:25:23 +03:00
parent c79ed97a47
commit 0b246e44a8
7 changed files with 196 additions and 34 deletions

View file

@ -13,33 +13,80 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Optional
from typing import Dict, List, Optional, Union, Callable
from aiohttp import ClientSession
import asyncio
import logging
from mautrix import Client as MatrixClient
from mautrix.client import EventHandler
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
EventType)
EventType, MessageEvent)
from .command_spec import ParsedCommand
from .db import DBClient
log = logging.getLogger("maubot.client")
class MaubotMatrixClient(MatrixClient):
def __init__(self, maubot_client: 'Client', *args, **kwargs):
super().__init__(*args, **kwargs)
self._maubot_client = maubot_client
self.command_handlers: Dict[str, List[EventHandler]] = {}
self.commands: List[ParsedCommand] = []
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
async def _command_event_handler(self, evt: MessageEvent) -> None:
for command in self.commands:
if command.match(evt):
await self._trigger_command(command, evt)
return
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
for handler in self.command_handlers.get(command.name, []):
await handler(evt)
def on(self, var: Union[EventHandler, EventType, str]
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
if isinstance(var, str):
def decorator(func: EventHandler) -> EventHandler:
self.add_command_handler(var, func)
return func
return decorator
return super().on(var)
def add_command_handler(self, command: str, handler: EventHandler) -> None:
self.command_handlers.setdefault(command, []).append(handler)
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
try:
self.command_handlers[command].remove(handler)
except (KeyError, ValueError):
pass
class Client:
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
db_instance: DBClient
client: MaubotMatrixClient
def __init__(self, db_instance: DBClient) -> None:
self.db_instance: DBClient = db_instance
self.db_instance = db_instance
self.cache[self.id] = self
self.client: MatrixClient = MatrixClient(mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id))
self.client = MaubotMatrixClient(maubot_client=self,
store=self.db_instance,
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=log.getChild(self.id))
if self.autojoin:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
@classmethod
def get(cls, user_id: UserID) -> Optional['Client']:
@ -103,9 +150,9 @@ class Client:
if value == self.db_instance.autojoin:
return
if value:
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
else:
self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
self.db_instance.autojoin = value
@property
@ -126,6 +173,6 @@ class Client:
# endregion
async def handle_invite(self, evt: StateEvent) -> None:
async def _handle_invite(self, evt: StateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room_by_id(evt.room_id)