mirror of
https://github.com/maubot/maubot
synced 2025-09-02 00:00:39 +00:00
Add command matching stuff
This commit is contained in:
parent
c79ed97a47
commit
0b246e44a8
7 changed files with 196 additions and 34 deletions
|
@ -13,33 +13,80 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional, Union, Callable
|
||||
from aiohttp import ClientSession
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from mautrix import Client as MatrixClient
|
||||
from mautrix.client import EventHandler
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StateEvent, Membership,
|
||||
EventType)
|
||||
EventType, MessageEvent)
|
||||
|
||||
from .command_spec import ParsedCommand
|
||||
from .db import DBClient
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
||||
class MaubotMatrixClient(MatrixClient):
|
||||
def __init__(self, maubot_client: 'Client', *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._maubot_client = maubot_client
|
||||
self.command_handlers: Dict[str, List[EventHandler]] = {}
|
||||
self.commands: List[ParsedCommand] = []
|
||||
|
||||
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
|
||||
|
||||
async def _command_event_handler(self, evt: MessageEvent) -> None:
|
||||
for command in self.commands:
|
||||
if command.match(evt):
|
||||
await self._trigger_command(command, evt)
|
||||
return
|
||||
|
||||
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
|
||||
for handler in self.command_handlers.get(command.name, []):
|
||||
await handler(evt)
|
||||
|
||||
def on(self, var: Union[EventHandler, EventType, str]
|
||||
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
|
||||
if isinstance(var, str):
|
||||
def decorator(func: EventHandler) -> EventHandler:
|
||||
self.add_command_handler(var, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
return super().on(var)
|
||||
|
||||
def add_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||
self.command_handlers.setdefault(command, []).append(handler)
|
||||
|
||||
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||
try:
|
||||
self.command_handlers[command].remove(handler)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class Client:
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
http_client: ClientSession = None
|
||||
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance: DBClient = db_instance
|
||||
self.db_instance = db_instance
|
||||
self.cache[self.id] = self
|
||||
self.client: MatrixClient = MatrixClient(mxid=self.id,
|
||||
base_url=self.homeserver,
|
||||
token=self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=log.getChild(self.id))
|
||||
self.client = MaubotMatrixClient(maubot_client=self,
|
||||
store=self.db_instance,
|
||||
mxid=self.id,
|
||||
base_url=self.homeserver,
|
||||
token=self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=log.getChild(self.id))
|
||||
if self.autojoin:
|
||||
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
|
||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID) -> Optional['Client']:
|
||||
|
@ -103,9 +150,9 @@ class Client:
|
|||
if value == self.db_instance.autojoin:
|
||||
return
|
||||
if value:
|
||||
self.client.add_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
|
||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
else:
|
||||
self.client.remove_event_handler(self.handle_invite, EventType.ROOM_MEMBER)
|
||||
self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
self.db_instance.autojoin = value
|
||||
|
||||
@property
|
||||
|
@ -126,6 +173,6 @@ class Client:
|
|||
|
||||
# endregion
|
||||
|
||||
async def handle_invite(self, evt: StateEvent) -> None:
|
||||
async def _handle_invite(self, evt: StateEvent) -> None:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room_by_id(evt.room_id)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue