mirror of
https://github.com/maubot/maubot
synced 2025-09-07 12:50:37 +00:00
parent
068e268c63
commit
21ed971d2f
43 changed files with 911 additions and 955 deletions
|
@ -13,24 +13,37 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import asyncio
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from mautrix.util.async_db import Database, DatabaseException
|
||||
from mautrix.util.program import Program
|
||||
|
||||
from .__meta__ import __version__
|
||||
from .client import Client, init as init_client_class
|
||||
from .client import Client
|
||||
from .config import Config
|
||||
from .db import init as init_db
|
||||
from .instance import init as init_plugin_instance_class
|
||||
from .db import init as init_db, upgrade_table
|
||||
from .instance import PluginInstance
|
||||
from .lib.future_awaitable import FutureAwaitable
|
||||
from .lib.state_store import PgStateStore
|
||||
from .loader.zip import init as init_zip_loader
|
||||
from .management.api import init as init_mgmt_api
|
||||
from .server import MaubotServer
|
||||
|
||||
try:
|
||||
from mautrix.crypto.store import PgCryptoStore
|
||||
except ImportError:
|
||||
PgCryptoStore = None
|
||||
|
||||
|
||||
class Maubot(Program):
|
||||
config: Config
|
||||
server: MaubotServer
|
||||
db: Database
|
||||
crypto_db: Database | None
|
||||
state_store: PgStateStore
|
||||
|
||||
config_class = Config
|
||||
module = "maubot"
|
||||
|
@ -45,6 +58,19 @@ class Maubot(Program):
|
|||
init(self.loop)
|
||||
self.add_shutdown_actions(FutureAwaitable(stop_all))
|
||||
|
||||
def prepare_arg_parser(self) -> None:
|
||||
super().prepare_arg_parser()
|
||||
self.parser.add_argument(
|
||||
"--ignore-unsupported-database",
|
||||
action="store_true",
|
||||
help="Run even if the database schema is too new",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--ignore-foreign-tables",
|
||||
action="store_true",
|
||||
help="Run even if the database contains tables from other programs (like Synapse)",
|
||||
)
|
||||
|
||||
def prepare(self) -> None:
|
||||
super().prepare()
|
||||
|
||||
|
@ -52,21 +78,59 @@ class Maubot(Program):
|
|||
self.prepare_log_websocket()
|
||||
|
||||
init_zip_loader(self.config)
|
||||
init_db(self.config)
|
||||
clients = init_client_class(self.config, self.loop)
|
||||
self.add_startup_actions(*(client.start() for client in clients))
|
||||
self.db = Database.create(
|
||||
self.config["database"],
|
||||
upgrade_table=upgrade_table,
|
||||
db_args=self.config["database_opts"],
|
||||
owner_name=self.name,
|
||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||
)
|
||||
init_db(self.db)
|
||||
if self.config["crypto_database"] == "default":
|
||||
self.crypto_db = self.db
|
||||
else:
|
||||
self.crypto_db = Database.create(
|
||||
self.config["crypto_database"],
|
||||
upgrade_table=PgCryptoStore.upgrade_table,
|
||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||
)
|
||||
Client.init_cls(self)
|
||||
PluginInstance.init_cls(self)
|
||||
management_api = init_mgmt_api(self.config, self.loop)
|
||||
self.server = MaubotServer(management_api, self.config, self.loop)
|
||||
self.state_store = PgStateStore(self.db)
|
||||
|
||||
plugins = init_plugin_instance_class(self.config, self.server, self.loop)
|
||||
for plugin in plugins:
|
||||
plugin.load()
|
||||
async def start_db(self) -> None:
|
||||
self.log.debug("Starting database...")
|
||||
ignore_unsupported = self.args.ignore_unsupported_database
|
||||
self.db.upgrade_table.allow_unsupported = ignore_unsupported
|
||||
self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
|
||||
PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
|
||||
try:
|
||||
await self.db.start()
|
||||
await self.state_store.upgrade_table.upgrade(self.db)
|
||||
if self.crypto_db and self.crypto_db is not self.db:
|
||||
await self.crypto_db.start()
|
||||
else:
|
||||
await PgCryptoStore.upgrade_table.upgrade(self.db)
|
||||
except DatabaseException as e:
|
||||
self.log.critical("Failed to initialize database", exc_info=e)
|
||||
if e.explanation:
|
||||
self.log.info(e.explanation)
|
||||
sys.exit(25)
|
||||
|
||||
async def system_exit(self) -> None:
|
||||
if hasattr(self, "db"):
|
||||
self.log.trace("Stopping database due to SystemExit")
|
||||
await self.db.stop()
|
||||
|
||||
async def start(self) -> None:
|
||||
if Client.crypto_db:
|
||||
self.log.debug("Starting client crypto database")
|
||||
await Client.crypto_db.start()
|
||||
await self.start_db()
|
||||
await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
|
||||
await asyncio.gather(*[client.start() async for client in Client.all()])
|
||||
await super().start()
|
||||
async for plugin in PluginInstance.all():
|
||||
await plugin.load()
|
||||
await self.server.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
|
@ -77,6 +141,7 @@ class Maubot(Program):
|
|||
await asyncio.wait_for(self.server.stop(), 5)
|
||||
except asyncio.TimeoutError:
|
||||
self.log.warning("Stopping server timed out")
|
||||
await self.db.stop()
|
||||
|
||||
|
||||
Maubot().run()
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = "0.2.1"
|
||||
__version__ = "0.3.0+dev"
|
||||
|
|
|
@ -38,7 +38,7 @@ def logs(server: str, tail: int) -> None:
|
|||
global history_count
|
||||
history_count = tail
|
||||
loop = asyncio.get_event_loop()
|
||||
future = asyncio.ensure_future(view_logs(server, token), loop=loop)
|
||||
future = asyncio.create_task(view_logs(server, token), loop=loop)
|
||||
try:
|
||||
loop.run_until_complete(future)
|
||||
except KeyboardInterrupt:
|
||||
|
|
407
maubot/client.py
407
maubot/client.py
|
@ -15,14 +15,14 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from mautrix.client import InternalEventType
|
||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||
from mautrix.errors import MatrixInvalidToken
|
||||
from mautrix.types import (
|
||||
ContentURI,
|
||||
|
@ -41,69 +41,110 @@ from mautrix.types import (
|
|||
SyncToken,
|
||||
UserID,
|
||||
)
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
from .db import DBClient
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .db import Client as DBClient
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
|
||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||
pass
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStore
|
||||
|
||||
crypto_import_error = None
|
||||
except ImportError as e:
|
||||
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
|
||||
SQLStateStore = BaseSQLStateStore
|
||||
OlmMachine = PgCryptoStore = None
|
||||
crypto_import_error = e
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import Config
|
||||
from .__main__ import Maubot
|
||||
from .instance import PluginInstance
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
||||
class Client:
|
||||
log: logging.Logger = None
|
||||
loop: asyncio.AbstractEventLoop = None
|
||||
class Client(DBClient):
|
||||
maubot: "Maubot" = None
|
||||
cache: dict[UserID, Client] = {}
|
||||
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
log: TraceLogger = logging.getLogger("maubot.client")
|
||||
|
||||
http_client: ClientSession = None
|
||||
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
|
||||
crypto_db: AsyncDatabase | None = None
|
||||
|
||||
references: set[PluginInstance]
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
crypto: OlmMachine | None
|
||||
crypto_store: PgCryptoStore | None
|
||||
started: bool
|
||||
sync_ok: bool
|
||||
|
||||
remote_displayname: str | None
|
||||
remote_avatar_url: ContentURI | None
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
def __init__(
|
||||
self,
|
||||
id: UserID,
|
||||
homeserver: str,
|
||||
access_token: str,
|
||||
device_id: DeviceID,
|
||||
enabled: bool = False,
|
||||
next_batch: SyncToken = "",
|
||||
filter_id: FilterID = "",
|
||||
sync: bool = True,
|
||||
autojoin: bool = True,
|
||||
online: bool = True,
|
||||
displayname: str = "disable",
|
||||
avatar_url: str = "disable",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
device_id=device_id,
|
||||
enabled=bool(enabled),
|
||||
next_batch=next_batch,
|
||||
filter_id=filter_id,
|
||||
sync=bool(sync),
|
||||
autojoin=bool(autojoin),
|
||||
online=bool(online),
|
||||
displayname=displayname,
|
||||
avatar_url=avatar_url,
|
||||
)
|
||||
self._postinited = False
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
@classmethod
|
||||
def init_cls(cls, maubot: "Maubot") -> None:
|
||||
cls.maubot = maubot
|
||||
|
||||
def _make_client(
|
||||
self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
|
||||
) -> MaubotMatrixClient:
|
||||
return MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=homeserver or self.homeserver,
|
||||
token=token or self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
crypto_log=self.log.getChild("crypto"),
|
||||
loop=self.maubot.loop,
|
||||
device_id=device_id or self.device_id,
|
||||
sync_store=self,
|
||||
state_store=self.maubot.state_store,
|
||||
)
|
||||
|
||||
def postinit(self) -> None:
|
||||
if self._postinited:
|
||||
raise RuntimeError("postinit() called twice")
|
||||
self._postinited = True
|
||||
self.cache[self.id] = self
|
||||
self.log = log.getChild(self.id)
|
||||
self.log = self.log.getChild(self.id)
|
||||
self.http_client = ClientSession(loop=self.maubot.loop)
|
||||
self.references = set()
|
||||
self.started = False
|
||||
self.sync_ok = True
|
||||
self.remote_displayname = None
|
||||
self.remote_avatar_url = None
|
||||
self.client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=self.homeserver,
|
||||
token=self.access_token,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
loop=self.loop,
|
||||
device_id=self.device_id,
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
self.client = self._make_client()
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
else:
|
||||
|
@ -118,6 +159,12 @@ class Client:
|
|||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
||||
return handler
|
||||
|
||||
@property
|
||||
def enable_crypto(self) -> bool:
|
||||
if not self.device_id:
|
||||
|
@ -131,16 +178,21 @@ class Client:
|
|||
# Clear the stack trace after it's logged once to avoid spamming logs
|
||||
crypto_import_error = None
|
||||
return False
|
||||
elif not self.crypto_db:
|
||||
elif not self.maubot.crypto_db:
|
||||
self.log.warning("Client has device ID, but crypto database is not prepared")
|
||||
return False
|
||||
return True
|
||||
|
||||
def _prepare_crypto(self) -> None:
|
||||
self.crypto_store = PgCryptoStore(
|
||||
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
|
||||
account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
|
||||
)
|
||||
self.crypto = OlmMachine(
|
||||
self.client,
|
||||
self.crypto_store,
|
||||
self.maubot.state_store,
|
||||
log=self.client.crypto_log,
|
||||
)
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
|
||||
def _remove_crypto_event_handlers(self) -> None:
|
||||
|
@ -156,12 +208,6 @@ class Client:
|
|||
for event_type, func in handlers:
|
||||
self.client.remove_event_handler(event_type, func)
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
||||
return handler
|
||||
|
||||
async def start(self, try_n: int | None = 0) -> None:
|
||||
try:
|
||||
if try_n > 0:
|
||||
|
@ -196,47 +242,50 @@ class Client:
|
|||
whoami = await self.client.whoami()
|
||||
except MatrixInvalidToken as e:
|
||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
except Exception as e:
|
||||
if try_n >= 8:
|
||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
else:
|
||||
self.log.warning(
|
||||
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
|
||||
f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
|
||||
)
|
||||
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
|
||||
_ = asyncio.create_task(self.start(try_n + 1))
|
||||
return
|
||||
if whoami.user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
|
||||
self.log.error(
|
||||
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
|
||||
)
|
||||
self.db_instance.enabled = False
|
||||
self.enabled = False
|
||||
await self.update()
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.db_instance.edit(
|
||||
filter_id=await self.client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
),
|
||||
self.filter_id = await self.client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(
|
||||
timeline=RoomEventFilter(
|
||||
limit=50,
|
||||
lazy_load_members=True,
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
state=StateFilter(
|
||||
lazy_load_members=True,
|
||||
),
|
||||
)
|
||||
),
|
||||
presence=EventFilter(
|
||||
not_types=[EventType.PRESENCE],
|
||||
),
|
||||
)
|
||||
)
|
||||
await self.update()
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
|
@ -270,18 +319,13 @@ class Client:
|
|||
if self.crypto:
|
||||
await self.crypto_store.close()
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
async def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
self.db_instance.edit(filter_id="", next_batch="")
|
||||
self.filter_id = FilterID("")
|
||||
self.next_batch = SyncToken("")
|
||||
await self.update()
|
||||
self.start_sync()
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db_instance.delete()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
|
@ -304,20 +348,6 @@ class Client:
|
|||
"instances": [instance.to_dict() for instance in self.references],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBClient.get(user_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return Client(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable[Client]:
|
||||
return (cls.get(user.id, user) for user in DBClient.all())
|
||||
|
||||
async def _handle_tombstone(self, evt: StateEvent) -> None:
|
||||
if not evt.content.replacement_room:
|
||||
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
|
||||
|
@ -329,7 +359,7 @@ class Client:
|
|||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room(evt.room_id)
|
||||
|
||||
async def update_started(self, started: bool) -> None:
|
||||
async def update_started(self, started: bool | None) -> None:
|
||||
if started is None or started == self.started:
|
||||
return
|
||||
if started:
|
||||
|
@ -337,23 +367,65 @@ class Client:
|
|||
else:
|
||||
await self.stop()
|
||||
|
||||
async def update_displayname(self, displayname: str) -> None:
|
||||
async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
|
||||
if enabled is None or enabled == self.enabled:
|
||||
return
|
||||
self.enabled = enabled
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
|
||||
if displayname is None or displayname == self.displayname:
|
||||
return
|
||||
self.db_instance.displayname = displayname
|
||||
self.displayname = displayname
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
else:
|
||||
await self._update_remote_profile()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_avatar_url(self, avatar_url: ContentURI) -> None:
|
||||
async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
|
||||
if avatar_url is None or avatar_url == self.avatar_url:
|
||||
return
|
||||
self.db_instance.avatar_url = avatar_url
|
||||
self.avatar_url = avatar_url
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
else:
|
||||
await self._update_remote_profile()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_sync(self, sync: bool | None, save: bool = True) -> None:
|
||||
if sync is None or self.sync == sync:
|
||||
return
|
||||
self.sync = sync
|
||||
if self.started:
|
||||
if sync:
|
||||
self.start_sync()
|
||||
else:
|
||||
self.stop_sync()
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
|
||||
if autojoin is None or autojoin == self.autojoin:
|
||||
return
|
||||
if autojoin:
|
||||
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
else:
|
||||
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
self.autojoin = autojoin
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_online(self, online: bool | None, save: bool = True) -> None:
|
||||
if online is None or online == self.online:
|
||||
return
|
||||
self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
|
||||
self.online = online
|
||||
if save:
|
||||
await self.update()
|
||||
|
||||
async def update_access_details(
|
||||
self,
|
||||
|
@ -373,22 +445,13 @@ class Client:
|
|||
and device_id == self.device_id
|
||||
):
|
||||
return
|
||||
new_client = MaubotMatrixClient(
|
||||
mxid=self.id,
|
||||
base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token,
|
||||
loop=self.loop,
|
||||
device_id=device_id,
|
||||
client_session=self.http_client,
|
||||
log=self.log,
|
||||
state_store=self.global_state_store,
|
||||
)
|
||||
new_client = self._make_client(homeserver, access_token, device_id)
|
||||
whoami = await new_client.whoami()
|
||||
if whoami.user_id != self.id:
|
||||
raise ValueError(f"MXID mismatch: {whoami.user_id}")
|
||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
|
||||
new_client.sync_store = SyncStoreProxy(self.db_instance)
|
||||
new_client.sync_store = self
|
||||
self.stop_sync()
|
||||
|
||||
# TODO this event handler transfer is pretty hacky
|
||||
|
@ -398,9 +461,9 @@ class Client:
|
|||
new_client.global_event_handlers = self.client.global_event_handlers
|
||||
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
self.db_instance.access_token = access_token
|
||||
self.db_instance.device_id = device_id
|
||||
self.homeserver = homeserver
|
||||
self.access_token = access_token
|
||||
self.device_id = device_id
|
||||
if self.enable_crypto:
|
||||
self._prepare_crypto()
|
||||
await self._start_crypto()
|
||||
|
@ -413,97 +476,53 @@ class Client:
|
|||
profile = await self.client.get_profile(self.id)
|
||||
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
|
||||
|
||||
# region Properties
|
||||
async def delete(self) -> None:
|
||||
try:
|
||||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
await super().delete()
|
||||
|
||||
@property
|
||||
def id(self) -> UserID:
|
||||
return self.db_instance.id
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get(
|
||||
cls,
|
||||
user_id: UserID,
|
||||
*,
|
||||
homeserver: str | None = None,
|
||||
access_token: str | None = None,
|
||||
device_id: DeviceID | None = None,
|
||||
) -> Client | None:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def homeserver(self) -> str:
|
||||
return self.db_instance.homeserver
|
||||
user = cast(cls, await super().get(user_id))
|
||||
if user is not None:
|
||||
user.postinit()
|
||||
return user
|
||||
|
||||
@property
|
||||
def access_token(self) -> str:
|
||||
return self.db_instance.access_token
|
||||
if homeserver and access_token:
|
||||
user = cls(
|
||||
user_id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
device_id=device_id or "",
|
||||
)
|
||||
await user.insert()
|
||||
user.postinit()
|
||||
return user
|
||||
|
||||
@property
|
||||
def device_id(self) -> DeviceID:
|
||||
return self.db_instance.device_id
|
||||
return None
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool) -> None:
|
||||
self.db_instance.enabled = value
|
||||
|
||||
@property
|
||||
def next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
|
||||
@property
|
||||
def filter_id(self) -> FilterID:
|
||||
return self.db_instance.filter_id
|
||||
|
||||
@property
|
||||
def sync(self) -> bool:
|
||||
return self.db_instance.sync
|
||||
|
||||
@sync.setter
|
||||
def sync(self, value: bool) -> None:
|
||||
if value == self.db_instance.sync:
|
||||
return
|
||||
self.db_instance.sync = value
|
||||
if self.started:
|
||||
if value:
|
||||
self.start_sync()
|
||||
else:
|
||||
self.stop_sync()
|
||||
|
||||
@property
|
||||
def autojoin(self) -> bool:
|
||||
return self.db_instance.autojoin
|
||||
|
||||
@autojoin.setter
|
||||
def autojoin(self, value: bool) -> None:
|
||||
if value == self.db_instance.autojoin:
|
||||
return
|
||||
if value:
|
||||
self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
else:
|
||||
self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
|
||||
self.db_instance.autojoin = value
|
||||
|
||||
@property
|
||||
def online(self) -> bool:
|
||||
return self.db_instance.online
|
||||
|
||||
@online.setter
|
||||
def online(self, value: bool) -> None:
|
||||
self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
|
||||
self.db_instance.online = value
|
||||
|
||||
@property
|
||||
def displayname(self) -> str:
|
||||
return self.db_instance.displayname
|
||||
|
||||
@property
|
||||
def avatar_url(self) -> ContentURI:
|
||||
return self.db_instance.avatar_url
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
|
||||
if OlmMachine:
|
||||
db_url = config["crypto_database"]
|
||||
if db_url == "default":
|
||||
db_url = config["database"]
|
||||
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
|
||||
|
||||
return Client.all()
|
||||
@classmethod
|
||||
async def all(cls) -> AsyncGenerator[Client, None]:
|
||||
users = await super().all()
|
||||
user: cls
|
||||
for user in users:
|
||||
try:
|
||||
yield cls.cache[user.id]
|
||||
except KeyError:
|
||||
user.postinit()
|
||||
yield user
|
||||
|
|
108
maubot/db.py
108
maubot/db.py
|
@ -1,108 +0,0 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Iterable, Optional
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, String, Text
|
||||
from sqlalchemy.engine.base import Engine
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
|
||||
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
|
||||
from mautrix.util.db import Base
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
||||
class DBPlugin(Base):
|
||||
__tablename__ = "plugin"
|
||||
|
||||
id: str = Column(String(255), primary_key=True)
|
||||
type: str = Column(String(255), nullable=False)
|
||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
||||
primary_user: UserID = Column(
|
||||
String(255),
|
||||
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
)
|
||||
config: str = Column(Text, nullable=False, default="")
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable["DBPlugin"]:
|
||||
return cls._select_all()
|
||||
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional["DBPlugin"]:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
class DBClient(Base):
|
||||
__tablename__ = "client"
|
||||
|
||||
id: UserID = Column(String(255), primary_key=True)
|
||||
homeserver: str = Column(String(255), nullable=False)
|
||||
access_token: str = Column(Text, nullable=False)
|
||||
device_id: DeviceID = Column(String(255), nullable=True)
|
||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
next_batch: SyncToken = Column(String(255), nullable=False, default="")
|
||||
filter_id: FilterID = Column(String(255), nullable=False, default="")
|
||||
|
||||
sync: bool = Column(Boolean, nullable=False, default=True)
|
||||
autojoin: bool = Column(Boolean, nullable=False, default=True)
|
||||
online: bool = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
displayname: str = Column(String(255), nullable=False, default="")
|
||||
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable["DBClient"]:
|
||||
return cls._select_all()
|
||||
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional["DBClient"]:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
def init(config: Config) -> Engine:
|
||||
db = sql.create_engine(config["database"])
|
||||
Base.metadata.bind = db
|
||||
|
||||
for table in (DBPlugin, DBClient, RoomState, UserProfile):
|
||||
table.bind(db)
|
||||
|
||||
if not db.has_table("alembic_version"):
|
||||
log = logging.getLogger("maubot.db")
|
||||
|
||||
if db.has_table("client") and db.has_table("plugin"):
|
||||
log.warning(
|
||||
"alembic_version table not found, but client and plugin tables found. "
|
||||
"Assuming pre-Alembic database and inserting version."
|
||||
)
|
||||
db.execute(
|
||||
"CREATE TABLE IF NOT EXISTS alembic_version ("
|
||||
" version_num VARCHAR(32) PRIMARY KEY"
|
||||
");"
|
||||
)
|
||||
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
|
||||
else:
|
||||
log.critical(
|
||||
"alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
|
||||
)
|
||||
sys.exit(10)
|
||||
|
||||
return db
|
13
maubot/db/__init__.py
Normal file
13
maubot/db/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from mautrix.util.async_db import Database
|
||||
|
||||
from .client import Client
|
||||
from .instance import Instance
|
||||
from .upgrade import upgrade_table
|
||||
|
||||
|
||||
def init(db: Database) -> None:
|
||||
for table in (Client, Instance):
|
||||
table.db = db
|
||||
|
||||
|
||||
__all__ = ["upgrade_table", "init", "Client", "Instance"]
|
114
maubot/db/client.py
Normal file
114
maubot/db/client.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.client import SyncStore
|
||||
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Client(SyncStore):
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
id: UserID
|
||||
homeserver: str
|
||||
access_token: str
|
||||
device_id: DeviceID
|
||||
enabled: bool
|
||||
|
||||
next_batch: SyncToken
|
||||
filter_id: FilterID
|
||||
|
||||
sync: bool
|
||||
autojoin: bool
|
||||
online: bool
|
||||
|
||||
displayname: str
|
||||
avatar_url: ContentURI
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record | None) -> Client | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
_columns = (
|
||||
"id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
|
||||
"sync, autojoin, online, displayname, avatar_url"
|
||||
)
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return (
|
||||
self.id,
|
||||
self.homeserver,
|
||||
self.access_token,
|
||||
self.device_id,
|
||||
self.enabled,
|
||||
self.next_batch,
|
||||
self.filter_id,
|
||||
self.sync,
|
||||
self.autojoin,
|
||||
self.online,
|
||||
self.displayname,
|
||||
self.avatar_url,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[Client]:
|
||||
rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def get(cls, id: str) -> Client | None:
|
||||
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, id))
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = """
|
||||
INSERT INTO client (
|
||||
id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
|
||||
sync, autojoin, online, displayname, avatar_url
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
"""
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def put_next_batch(self, next_batch: SyncToken) -> None:
|
||||
await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
|
||||
self.next_batch = next_batch
|
||||
|
||||
async def get_next_batch(self) -> SyncToken:
|
||||
return self.next_batch
|
||||
|
||||
async def update(self) -> None:
|
||||
q = """
|
||||
UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
|
||||
next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
|
||||
displayname=$11, avatar_url=$12
|
||||
WHERE id=$1
|
||||
"""
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def delete(self) -> None:
|
||||
await self.db.execute("DELETE FROM client WHERE id=$1", self.id)
|
75
maubot/db/instance.py
Normal file
75
maubot/db/instance.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from asyncpg import Record
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
fake_db = Database.create("") if TYPE_CHECKING else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Instance:
|
||||
db: ClassVar[Database] = fake_db
|
||||
|
||||
id: str
|
||||
type: str
|
||||
enabled: bool
|
||||
primary_user: UserID
|
||||
config_str: str
|
||||
|
||||
@classmethod
|
||||
def _from_row(cls, row: Record | None) -> Instance | None:
|
||||
if row is None:
|
||||
return None
|
||||
return cls(**row)
|
||||
|
||||
@classmethod
|
||||
async def all(cls) -> list[Instance]:
|
||||
rows = await cls.db.fetch("SELECT id, type, enabled, primary_user, config FROM instance")
|
||||
return [cls._from_row(row) for row in rows]
|
||||
|
||||
@classmethod
|
||||
async def get(cls, id: str) -> Instance | None:
|
||||
q = "SELECT id, type, enabled, primary_user, config FROM instance WHERE id=$1"
|
||||
return cls._from_row(await cls.db.fetchrow(q, id))
|
||||
|
||||
async def update_id(self, new_id: str) -> None:
|
||||
await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
|
||||
self.id = new_id
|
||||
|
||||
@property
|
||||
def _values(self):
|
||||
return self.id, self.type, self.enabled, self.primary_user, self.config_str
|
||||
|
||||
async def insert(self) -> None:
|
||||
q = (
|
||||
"INSERT INTO instance (id, type, enabled, primary_user, config) "
|
||||
"VALUES ($1, $2, $3, $4, $5)"
|
||||
)
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def update(self) -> None:
|
||||
q = "UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5 WHERE id=$1"
|
||||
await self.db.execute(q, *self._values)
|
||||
|
||||
async def delete(self) -> None:
|
||||
await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)
|
5
maubot/db/upgrade/__init__.py
Normal file
5
maubot/db/upgrade/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from mautrix.util.async_db import UpgradeTable
|
||||
|
||||
upgrade_table = UpgradeTable()
|
||||
|
||||
from . import v01_initial_revision
|
136
maubot/db/upgrade/v01_initial_revision.py
Normal file
136
maubot/db/upgrade/v01_initial_revision.py
Normal file
|
@ -0,0 +1,136 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from mautrix.util.async_db import Connection, Scheme
|
||||
|
||||
from . import upgrade_table
|
||||
|
||||
legacy_version_query = "SELECT version_num FROM alembic_version"
|
||||
last_legacy_version = "90aa88820eab"
|
||||
|
||||
|
||||
@upgrade_table.register(description="Initial asyncpg revision")
|
||||
async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
|
||||
if await conn.table_exists("alembic_version"):
|
||||
await migrate_legacy_to_v1(conn, scheme)
|
||||
else:
|
||||
return await create_v1_tables(conn)
|
||||
|
||||
|
||||
async def create_v1_tables(conn: Connection) -> None:
|
||||
await conn.execute(
|
||||
"""CREATE TABLE client (
|
||||
id TEXT PRIMARY KEY,
|
||||
homeserver TEXT NOT NULL,
|
||||
access_token TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL,
|
||||
|
||||
next_batch TEXT NOT NULL,
|
||||
filter_id TEXT NOT NULL,
|
||||
|
||||
sync BOOLEAN NOT NULL,
|
||||
autojoin BOOLEAN NOT NULL,
|
||||
online BOOLEAN NOT NULL,
|
||||
|
||||
displayname TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""CREATE TABLE instance (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL,
|
||||
primary_user TEXT NOT NULL,
|
||||
config TEXT NOT NULL,
|
||||
FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
)"""
|
||||
)
|
||||
|
||||
|
||||
async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
|
||||
legacy_version = await conn.fetchval(legacy_version_query)
|
||||
if legacy_version != last_legacy_version:
|
||||
raise RuntimeError(
|
||||
"Legacy database is not on last version. "
|
||||
"Please upgrade the old database with alembic or drop it completely first."
|
||||
)
|
||||
await conn.execute("ALTER TABLE plugin RENAME TO instance")
|
||||
await update_state_store(conn, scheme)
|
||||
if scheme != Scheme.SQLITE:
|
||||
await varchar_to_text(conn)
|
||||
await conn.execute("DROP TABLE alembic_version")
|
||||
|
||||
|
||||
async def update_state_store(conn: Connection, scheme: Scheme) -> None:
|
||||
# The Matrix state store already has more or less the correct schema, so set the version
|
||||
await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
|
||||
await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
|
||||
if scheme != Scheme.SQLITE:
|
||||
# Remove old uppercase membership type and recreate it as lowercase
|
||||
await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
|
||||
await conn.execute("DROP TYPE IF EXISTS membership")
|
||||
await conn.execute(
|
||||
"CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
|
||||
)
|
||||
await conn.execute(
|
||||
"ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
|
||||
"USING LOWER(membership)::membership"
|
||||
)
|
||||
else:
|
||||
# Recreate table to remove CHECK constraint and lowercase everything
|
||||
await conn.execute(
|
||||
"""CREATE TABLE new_mx_user_profile (
|
||||
room_id TEXT,
|
||||
user_id TEXT,
|
||||
membership TEXT NOT NULL
|
||||
CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
|
||||
displayname TEXT,
|
||||
avatar_url TEXT,
|
||||
PRIMARY KEY (room_id, user_id)
|
||||
)"""
|
||||
)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
|
||||
SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
|
||||
FROM mx_user_profile
|
||||
"""
|
||||
)
|
||||
await conn.execute("DROP TABLE mx_user_profile")
|
||||
await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
|
||||
|
||||
|
||||
async def varchar_to_text(conn: Connection) -> None:
|
||||
columns_to_adjust = {
|
||||
"client": (
|
||||
"id",
|
||||
"homeserver",
|
||||
"device_id",
|
||||
"next_batch",
|
||||
"filter_id",
|
||||
"displayname",
|
||||
"avatar_url",
|
||||
),
|
||||
"instance": ("id", "type", "primary_user"),
|
||||
"mx_room_state": ("room_id",),
|
||||
"mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
|
||||
}
|
||||
for table, columns in columns_to_adjust.items():
|
||||
for column in columns:
|
||||
await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')
|
|
@ -6,9 +6,7 @@
|
|||
database: sqlite:///maubot.db
|
||||
|
||||
# Separate database URL for the crypto database. "default" means use the same database as above.
|
||||
# Due to concurrency issues, you should use a separate file when using SQLite rather than the same as above.
|
||||
# When using postgres, using the same database for both is safe.
|
||||
crypto_database: sqlite:///crypto.db
|
||||
crypto_database: default
|
||||
|
||||
plugin_directories:
|
||||
# The directory where uploaded new plugins should be stored.
|
||||
|
|
|
@ -15,8 +15,10 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
from asyncio import AbstractEventLoop
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
import os.path
|
||||
|
@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
|
|||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||
|
||||
from .client import Client
|
||||
from .config import Config
|
||||
from .db import DBPlugin
|
||||
from .db import Instance as DBInstance
|
||||
from .loader import PluginLoader, ZippedPluginLoader
|
||||
from .plugin_base import Plugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import MaubotServer, PluginWebApp
|
||||
from .__main__ import Maubot
|
||||
from .server import PluginWebApp
|
||||
|
||||
log = logging.getLogger("maubot.instance")
|
||||
|
||||
|
@ -44,29 +47,42 @@ yaml.indent(4)
|
|||
yaml.width = 200
|
||||
|
||||
|
||||
class PluginInstance:
|
||||
webserver: MaubotServer = None
|
||||
mb_config: Config = None
|
||||
loop: AbstractEventLoop = None
|
||||
class PluginInstance(DBInstance):
|
||||
maubot: "Maubot" = None
|
||||
cache: dict[str, PluginInstance] = {}
|
||||
plugin_directories: list[str] = []
|
||||
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
|
||||
log: logging.Logger
|
||||
loader: PluginLoader
|
||||
client: Client
|
||||
plugin: Plugin
|
||||
config: BaseProxyConfig
|
||||
loader: PluginLoader | None
|
||||
client: Client | None
|
||||
plugin: Plugin | None
|
||||
config: BaseProxyConfig | None
|
||||
base_cfg: RecursiveDict[CommentedMap] | None
|
||||
base_cfg_str: str | None
|
||||
inst_db: sql.engine.Engine
|
||||
inst_db_tables: dict[str, sql.Table]
|
||||
inst_db: sql.engine.Engine | None
|
||||
inst_db_tables: dict[str, sql.Table] | None
|
||||
inst_webapp: PluginWebApp | None
|
||||
inst_webapp_url: str | None
|
||||
started: bool
|
||||
|
||||
def __init__(self, db_instance: DBPlugin):
|
||||
self.db_instance = db_instance
|
||||
def __init__(
|
||||
self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
@classmethod
|
||||
def init_cls(cls, maubot: "Maubot") -> None:
|
||||
cls.maubot = maubot
|
||||
|
||||
def postinit(self) -> None:
|
||||
self.log = log.getChild(self.id)
|
||||
self.cache[self.id] = self
|
||||
self.config = None
|
||||
self.started = False
|
||||
self.loader = None
|
||||
|
@ -78,7 +94,6 @@ class PluginInstance:
|
|||
self.inst_webapp_url = None
|
||||
self.base_cfg = None
|
||||
self.base_cfg_str = None
|
||||
self.cache[self.id] = self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
@ -87,10 +102,10 @@ class PluginInstance:
|
|||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"primary_user": self.primary_user,
|
||||
"config": self.db_instance.config,
|
||||
"config": self.config_str,
|
||||
"base_config": self.base_cfg_str,
|
||||
"database": (
|
||||
self.inst_db is not None and self.mb_config["api_features.instance_database"]
|
||||
self.inst_db is not None and self.maubot.config["api_features.instance_database"]
|
||||
),
|
||||
}
|
||||
|
||||
|
@ -101,19 +116,19 @@ class PluginInstance:
|
|||
self.inst_db_tables = metadata.tables
|
||||
return self.inst_db_tables
|
||||
|
||||
def load(self) -> bool:
|
||||
async def load(self) -> bool:
|
||||
if not self.loader:
|
||||
try:
|
||||
self.loader = PluginLoader.find(self.type)
|
||||
except KeyError:
|
||||
self.log.error(f"Failed to find loader for type {self.type}")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return False
|
||||
if not self.client:
|
||||
self.client = Client.get(self.primary_user)
|
||||
self.client = await Client.get(self.primary_user)
|
||||
if not self.client:
|
||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return False
|
||||
if self.loader.meta.database:
|
||||
self.enable_database()
|
||||
|
@ -125,18 +140,18 @@ class PluginInstance:
|
|||
return True
|
||||
|
||||
def enable_webapp(self) -> None:
|
||||
self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
|
||||
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
|
||||
|
||||
def disable_webapp(self) -> None:
|
||||
self.webserver.remove_instance_webapp(self.id)
|
||||
self.maubot.server.remove_instance_webapp(self.id)
|
||||
self.inst_webapp = None
|
||||
self.inst_webapp_url = None
|
||||
|
||||
def enable_database(self) -> None:
|
||||
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
|
||||
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
|
||||
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
|
||||
|
||||
def delete(self) -> None:
|
||||
async def delete(self) -> None:
|
||||
if self.loader is not None:
|
||||
self.loader.references.remove(self)
|
||||
if self.client is not None:
|
||||
|
@ -145,23 +160,23 @@ class PluginInstance:
|
|||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db_instance.delete()
|
||||
await super().delete()
|
||||
if self.inst_db:
|
||||
self.inst_db.dispose()
|
||||
ZippedPluginLoader.trash(
|
||||
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
|
||||
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
|
||||
reason="deleted",
|
||||
)
|
||||
if self.inst_webapp:
|
||||
self.disable_webapp()
|
||||
|
||||
def load_config(self) -> CommentedMap:
|
||||
return yaml.load(self.db_instance.config)
|
||||
return yaml.load(self.config_str)
|
||||
|
||||
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
||||
buf = io.StringIO()
|
||||
yaml.dump(data, buf)
|
||||
self.db_instance.config = buf.getvalue()
|
||||
self.config_str = buf.getvalue()
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.started:
|
||||
|
@ -172,7 +187,7 @@ class PluginInstance:
|
|||
return
|
||||
if not self.client or not self.loader:
|
||||
self.log.warning("Missing plugin instance dependencies, attempting to load...")
|
||||
if not self.load():
|
||||
if not await self.load():
|
||||
return
|
||||
cls = await self.loader.load()
|
||||
if self.loader.meta.webapp and self.inst_webapp is None:
|
||||
|
@ -205,7 +220,7 @@ class PluginInstance:
|
|||
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
|
||||
self.plugin = cls(
|
||||
client=self.client.client,
|
||||
loop=self.loop,
|
||||
loop=self.maubot.loop,
|
||||
http=self.client.http_client,
|
||||
instance_id=self.id,
|
||||
log=self.log,
|
||||
|
@ -219,7 +234,7 @@ class PluginInstance:
|
|||
await self.plugin.internal_start()
|
||||
except Exception:
|
||||
self.log.exception("Failed to start instance")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return
|
||||
self.started = True
|
||||
self.inst_db_tables = None
|
||||
|
@ -241,60 +256,51 @@ class PluginInstance:
|
|||
self.plugin = None
|
||||
self.inst_db_tables = None
|
||||
|
||||
@classmethod
|
||||
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBPlugin.get(instance_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return PluginInstance(db_instance)
|
||||
async def update_id(self, new_id: str | None) -> None:
|
||||
if new_id is not None and new_id.lower() != self.id:
|
||||
await super().update_id(new_id.lower())
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable[PluginInstance]:
|
||||
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
|
||||
|
||||
def update_id(self, new_id: str) -> None:
|
||||
if new_id is not None and new_id != self.id:
|
||||
self.db_instance.id = new_id.lower()
|
||||
|
||||
def update_config(self, config: str) -> None:
|
||||
if not config or self.db_instance.config == config:
|
||||
async def update_config(self, config: str | None) -> None:
|
||||
if config is None or self.config_str == config:
|
||||
return
|
||||
self.db_instance.config = config
|
||||
self.config_str = config
|
||||
if self.started and self.plugin is not None:
|
||||
self.plugin.on_external_config_update()
|
||||
res = self.plugin.on_external_config_update()
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
await self.update()
|
||||
|
||||
async def update_primary_user(self, primary_user: UserID) -> bool:
|
||||
if not primary_user or primary_user == self.primary_user:
|
||||
async def update_primary_user(self, primary_user: UserID | None) -> bool:
|
||||
if primary_user is None or primary_user == self.primary_user:
|
||||
return True
|
||||
client = Client.get(primary_user)
|
||||
client = await Client.get(primary_user)
|
||||
if not client:
|
||||
return False
|
||||
await self.stop()
|
||||
self.db_instance.primary_user = client.id
|
||||
self.primary_user = client.id
|
||||
if self.client:
|
||||
self.client.references.remove(self)
|
||||
self.client = client
|
||||
self.client.references.add(self)
|
||||
await self.update()
|
||||
await self.start()
|
||||
self.log.debug(f"Primary user switched to {self.client.id}")
|
||||
return True
|
||||
|
||||
async def update_type(self, type: str) -> bool:
|
||||
if not type or type == self.type:
|
||||
async def update_type(self, type: str | None) -> bool:
|
||||
if type is None or type == self.type:
|
||||
return True
|
||||
try:
|
||||
loader = PluginLoader.find(type)
|
||||
except KeyError:
|
||||
return False
|
||||
await self.stop()
|
||||
self.db_instance.type = loader.meta.id
|
||||
self.type = loader.meta.id
|
||||
if self.loader:
|
||||
self.loader.references.remove(self)
|
||||
self.loader = loader
|
||||
self.loader.references.add(self)
|
||||
await self.update()
|
||||
await self.start()
|
||||
self.log.debug(f"Type switched to {self.loader.meta.id}")
|
||||
return True
|
||||
|
@ -303,39 +309,41 @@ class PluginInstance:
|
|||
if started is not None and started != self.started:
|
||||
await (self.start() if started else self.stop())
|
||||
|
||||
def update_enabled(self, enabled: bool) -> None:
|
||||
async def update_enabled(self, enabled: bool) -> None:
|
||||
if enabled is not None and enabled != self.enabled:
|
||||
self.db_instance.enabled = enabled
|
||||
self.enabled = enabled
|
||||
await self.update()
|
||||
|
||||
# region Properties
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get(
|
||||
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
|
||||
) -> PluginInstance | None:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.db_instance.id
|
||||
instance = cast(cls, await super().get(instance_id))
|
||||
if instance is not None:
|
||||
instance.postinit()
|
||||
return instance
|
||||
|
||||
@id.setter
|
||||
def id(self, value: str) -> None:
|
||||
self.db_instance.id = value
|
||||
if type and primary_user:
|
||||
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
|
||||
await instance.insert()
|
||||
instance.postinit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.db_instance.type
|
||||
return None
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@property
|
||||
def primary_user(self) -> UserID:
|
||||
return self.db_instance.primary_user
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(
|
||||
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
|
||||
) -> Iterable[PluginInstance]:
|
||||
PluginInstance.mb_config = config
|
||||
PluginInstance.loop = loop
|
||||
PluginInstance.webserver = webserver
|
||||
return PluginInstance.all()
|
||||
@classmethod
|
||||
async def all(cls) -> AsyncGenerator[PluginInstance, None]:
|
||||
instances = await super().all()
|
||||
instance: PluginInstance
|
||||
for instance in instances:
|
||||
try:
|
||||
yield cls.cache[instance.id]
|
||||
except KeyError:
|
||||
instance.postinit()
|
||||
yield instance
|
||||
|
|
|
@ -28,14 +28,19 @@ LOADER_COLOR = PREFIX + "36m" # blue
|
|||
class ColorFormatter(BaseColorFormatter):
|
||||
def _color_name(self, module: str) -> str:
|
||||
client = "maubot.client"
|
||||
if module.startswith(client):
|
||||
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
|
||||
if module.startswith(client + "."):
|
||||
suffix = ""
|
||||
if module.endswith(".crypto"):
|
||||
suffix = f".{MAU_COLOR}crypto{RESET}"
|
||||
module = module[: -len(".crypto")]
|
||||
module = module[len(client) + 1 :]
|
||||
return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
|
||||
instance = "maubot.instance"
|
||||
if module.startswith(instance):
|
||||
if module.startswith(instance + "."):
|
||||
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
|
||||
loader = "maubot.loader"
|
||||
if module.startswith(loader):
|
||||
if module.startswith(loader + "."):
|
||||
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
|
||||
if module.startswith("maubot"):
|
||||
if module.startswith("maubot."):
|
||||
return f"{MAU_COLOR}{module}{RESET}"
|
||||
return super()._color_name(module)
|
||||
|
|
|
@ -13,16 +13,15 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from mautrix.client import SyncStore
|
||||
from mautrix.types import SyncToken
|
||||
from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
|
||||
|
||||
try:
|
||||
from mautrix.crypto import StateStore as CryptoStateStore
|
||||
|
||||
class SyncStoreProxy(SyncStore):
|
||||
def __init__(self, db_instance) -> None:
|
||||
self.db_instance = db_instance
|
||||
class PgStateStore(BasePgStateStore, CryptoStateStore):
|
||||
pass
|
||||
|
||||
async def put_next_batch(self, next_batch: SyncToken) -> None:
|
||||
self.db_instance.edit(next_batch=next_batch)
|
||||
except ImportError as e:
|
||||
PgStateStore = BasePgStateStore
|
||||
|
||||
async def get_next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
__all__ = ["PgStateStore"]
|
|
@ -13,17 +13,14 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
|
||||
from attr import dataclass
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
||||
|
||||
from ..__meta__ import __version__
|
||||
from ..plugin_base import Plugin
|
||||
from .meta import PluginMeta
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..instance import PluginInstance
|
||||
|
@ -35,36 +32,6 @@ class IDConflictError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
@serializer(Version)
|
||||
def serialize_version(version: Version) -> str:
|
||||
return str(version)
|
||||
|
||||
|
||||
@deserializer(Version)
|
||||
def deserialize_version(version: str) -> Version:
|
||||
try:
|
||||
return Version(version)
|
||||
except InvalidVersion as e:
|
||||
raise SerializerError("Invalid version") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginMeta(SerializableAttrs):
|
||||
id: str
|
||||
version: Version
|
||||
modules: List[str]
|
||||
main_class: str
|
||||
|
||||
maubot: Version = Version(__version__)
|
||||
database: bool = False
|
||||
config: bool = False
|
||||
webapp: bool = False
|
||||
license: str = ""
|
||||
extra_files: List[str] = []
|
||||
dependencies: List[str] = []
|
||||
soft_dependencies: List[str] = []
|
||||
|
||||
|
||||
class BasePluginLoader(ABC):
|
||||
meta: PluginMeta
|
||||
|
||||
|
@ -80,25 +47,25 @@ class BasePluginLoader(ABC):
|
|||
async def read_file(self, path: str) -> bytes:
|
||||
pass
|
||||
|
||||
def sync_list_files(self, directory: str) -> List[str]:
|
||||
def sync_list_files(self, directory: str) -> list[str]:
|
||||
raise NotImplementedError("This loader doesn't support synchronous operations")
|
||||
|
||||
@abstractmethod
|
||||
async def list_files(self, directory: str) -> List[str]:
|
||||
async def list_files(self, directory: str) -> list[str]:
|
||||
pass
|
||||
|
||||
|
||||
class PluginLoader(BasePluginLoader, ABC):
|
||||
id_cache: Dict[str, "PluginLoader"] = {}
|
||||
id_cache: dict[str, PluginLoader] = {}
|
||||
|
||||
meta: PluginMeta
|
||||
references: Set["PluginInstance"]
|
||||
references: set[PluginInstance]
|
||||
|
||||
def __init__(self):
|
||||
self.references = set()
|
||||
|
||||
@classmethod
|
||||
def find(cls, plugin_id: str) -> "PluginLoader":
|
||||
def find(cls, plugin_id: str) -> PluginLoader:
|
||||
return cls.id_cache[plugin_id]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
@ -119,11 +86,11 @@ class PluginLoader(BasePluginLoader, ABC):
|
|||
)
|
||||
|
||||
@abstractmethod
|
||||
async def load(self) -> Type[PluginClass]:
|
||||
async def load(self) -> type[PluginClass]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reload(self) -> Type[PluginClass]:
|
||||
async def reload(self) -> type[PluginClass]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
|
53
maubot/loader/meta.py
Normal file
53
maubot/loader/meta.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2022 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List
|
||||
|
||||
from attr import dataclass
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
||||
|
||||
from ..__meta__ import __version__
|
||||
|
||||
|
||||
@serializer(Version)
|
||||
def serialize_version(version: Version) -> str:
|
||||
return str(version)
|
||||
|
||||
|
||||
@deserializer(Version)
|
||||
def deserialize_version(version: str) -> Version:
|
||||
try:
|
||||
return Version(version)
|
||||
except InvalidVersion as e:
|
||||
raise SerializerError("Invalid version") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginMeta(SerializableAttrs):
|
||||
id: str
|
||||
version: Version
|
||||
modules: List[str]
|
||||
main_class: str
|
||||
|
||||
maubot: Version = Version(__version__)
|
||||
database: bool = False
|
||||
config: bool = False
|
||||
webapp: bool = False
|
||||
license: str = ""
|
||||
extra_files: List[str] = []
|
||||
dependencies: List[str] = []
|
||||
soft_dependencies: List[str] = []
|
|
@ -29,7 +29,8 @@ from mautrix.types import SerializerError
|
|||
from ..config import Config
|
||||
from ..lib.zipimport import ZipImportError, zipimporter
|
||||
from ..plugin_base import Plugin
|
||||
from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta
|
||||
from .abc import IDConflictError, PluginClass, PluginLoader
|
||||
from .meta import PluginMeta
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from aiohttp import web
|
|||
|
||||
from ...config import Config
|
||||
from .auth import check_token
|
||||
from .base import get_config, routes, set_config, set_loop
|
||||
from .base import get_config, routes, set_config
|
||||
from .middleware import auth, error
|
||||
|
||||
|
||||
|
@ -40,7 +40,6 @@ def features(request: web.Request) -> web.Response:
|
|||
|
||||
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
||||
set_config(cfg)
|
||||
set_loop(loop)
|
||||
for pkg, enabled in cfg["api_features"].items():
|
||||
if enabled:
|
||||
importlib.import_module(f"maubot.management.api.{pkg}")
|
||||
|
|
|
@ -46,7 +46,7 @@ def create_token(user: UserID) -> str:
|
|||
def get_token(request: web.Request) -> str:
|
||||
token = request.headers.get("Authorization", "")
|
||||
if not token or not token.startswith("Bearer "):
|
||||
token = request.query.get("access_token", None)
|
||||
token = request.query.get("access_token", "")
|
||||
else:
|
||||
token = token[len("Bearer ") :]
|
||||
return token
|
||||
|
|
|
@ -24,7 +24,6 @@ from ...config import Config
|
|||
|
||||
routes: web.RouteTableDef = web.RouteTableDef()
|
||||
_config: Config | None = None
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
||||
def set_config(config: Config) -> None:
|
||||
|
@ -36,15 +35,6 @@ def get_config() -> Config:
|
|||
return _config
|
||||
|
||||
|
||||
def set_loop(loop: asyncio.AbstractEventLoop) -> None:
|
||||
global _loop
|
||||
_loop = loop
|
||||
|
||||
|
||||
def get_loop() -> asyncio.AbstractEventLoop:
|
||||
return _loop
|
||||
|
||||
|
||||
@routes.get("/version")
|
||||
async def version(_: web.Request) -> web.Response:
|
||||
return web.json_response({"version": __version__})
|
||||
|
|
|
@ -24,7 +24,6 @@ from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequ
|
|||
from mautrix.types import FilterID, SyncToken, UserID
|
||||
|
||||
from ...client import Client
|
||||
from ...db import DBClient
|
||||
from .base import routes
|
||||
from .responses import resp
|
||||
|
||||
|
@ -37,7 +36,7 @@ async def get_clients(_: web.Request) -> web.Response:
|
|||
@routes.get("/client/{id}")
|
||||
async def get_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
return resp.found(client.to_dict())
|
||||
|
@ -51,7 +50,6 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
|||
mxid="@not:a.mxid",
|
||||
base_url=homeserver,
|
||||
token=access_token,
|
||||
loop=Client.loop,
|
||||
client_session=Client.http_client,
|
||||
)
|
||||
try:
|
||||
|
@ -63,29 +61,23 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
|||
except MatrixConnectionError:
|
||||
return resp.bad_client_connection_details
|
||||
if user_id is None:
|
||||
existing_client = Client.get(whoami.user_id, None)
|
||||
existing_client = await Client.get(whoami.user_id)
|
||||
if existing_client is not None:
|
||||
return resp.user_exists
|
||||
elif whoami.user_id != user_id:
|
||||
return resp.mxid_mismatch(whoami.user_id)
|
||||
elif whoami.device_id and device_id and whoami.device_id != device_id:
|
||||
return resp.device_id_mismatch(whoami.device_id)
|
||||
db_instance = DBClient(
|
||||
id=whoami.user_id,
|
||||
homeserver=homeserver,
|
||||
access_token=access_token,
|
||||
enabled=data.get("enabled", True),
|
||||
next_batch=SyncToken(""),
|
||||
filter_id=FilterID(""),
|
||||
sync=data.get("sync", True),
|
||||
autojoin=data.get("autojoin", True),
|
||||
online=data.get("online", True),
|
||||
displayname=data.get("displayname", "disable"),
|
||||
avatar_url=data.get("avatar_url", "disable"),
|
||||
device_id=device_id,
|
||||
client = await Client.get(
|
||||
whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id
|
||||
)
|
||||
client = Client(db_instance)
|
||||
client.db_instance.insert()
|
||||
client.enabled = data.get("enabled", True)
|
||||
client.sync = data.get("sync", True)
|
||||
client.autojoin = data.get("autojoin", True)
|
||||
client.online = data.get("online", True)
|
||||
client.displayname = data.get("displayname", "disable")
|
||||
client.avatar_url = data.get("avatar_url", "disable")
|
||||
await client.update()
|
||||
await client.start()
|
||||
return resp.created(client.to_dict())
|
||||
|
||||
|
@ -93,9 +85,7 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
|
|||
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
|
||||
try:
|
||||
await client.update_access_details(
|
||||
data.get("access_token", None),
|
||||
data.get("homeserver", None),
|
||||
data.get("device_id", None),
|
||||
data.get("access_token"), data.get("homeserver"), data.get("device_id")
|
||||
)
|
||||
except MatrixInvalidToken:
|
||||
return resp.bad_client_access_token
|
||||
|
@ -109,21 +99,21 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
|
|||
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
|
||||
elif str_err.startswith("Device ID mismatch"):
|
||||
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
|
||||
with client.db_instance.edit_mode():
|
||||
await client.update_avatar_url(data.get("avatar_url", None))
|
||||
await client.update_displayname(data.get("displayname", None))
|
||||
await client.update_started(data.get("started", None))
|
||||
client.enabled = data.get("enabled", client.enabled)
|
||||
client.autojoin = data.get("autojoin", client.autojoin)
|
||||
client.online = data.get("online", client.online)
|
||||
client.sync = data.get("sync", client.sync)
|
||||
return resp.updated(client.to_dict(), is_login=is_login)
|
||||
await client.update_avatar_url(data.get("avatar_url"), save=False)
|
||||
await client.update_displayname(data.get("displayname"), save=False)
|
||||
await client.update_started(data.get("started"))
|
||||
await client.update_enabled(data.get("enabled"), save=False)
|
||||
await client.update_autojoin(data.get("autojoin"), save=False)
|
||||
await client.update_online(data.get("online"), save=False)
|
||||
await client.update_sync(data.get("sync"), save=False)
|
||||
await client.update()
|
||||
return resp.updated(client.to_dict(), is_login=is_login)
|
||||
|
||||
|
||||
async def _create_or_update_client(
|
||||
user_id: UserID, data: dict, is_login: bool = False
|
||||
) -> web.Response:
|
||||
client = Client.get(user_id, None)
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return await _create_client(user_id, data)
|
||||
else:
|
||||
|
@ -141,7 +131,7 @@ async def create_client(request: web.Request) -> web.Response:
|
|||
|
||||
@routes.put("/client/{id}")
|
||||
async def update_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
user_id = request.match_info["id"]
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
|
@ -151,23 +141,23 @@ async def update_client(request: web.Request) -> web.Response:
|
|||
|
||||
@routes.delete("/client/{id}")
|
||||
async def delete_client(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
user_id = request.match_info["id"]
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
if len(client.references) > 0:
|
||||
return resp.client_in_use
|
||||
if client.started:
|
||||
await client.stop()
|
||||
client.delete()
|
||||
await client.delete()
|
||||
return resp.deleted
|
||||
|
||||
|
||||
@routes.post("/client/{id}/clearcache")
|
||||
async def clear_client_cache(request: web.Request) -> web.Response:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
user_id = request.match_info["id"]
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
client.clear_cache()
|
||||
await client.clear_cache()
|
||||
return resp.ok
|
||||
|
|
|
@ -13,7 +13,9 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, NamedTuple, Optional, Tuple
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
from http import HTTPStatus
|
||||
from json import JSONDecodeError
|
||||
import asyncio
|
||||
|
@ -30,12 +32,12 @@ from mautrix.client import ClientAPI
|
|||
from mautrix.errors import MatrixRequestError
|
||||
from mautrix.types import LoginResponse, LoginType
|
||||
|
||||
from .base import get_config, get_loop, routes
|
||||
from .base import get_config, routes
|
||||
from .client import _create_client, _create_or_update_client
|
||||
from .responses import resp
|
||||
|
||||
|
||||
def known_homeservers() -> Dict[str, Dict[str, str]]:
|
||||
def known_homeservers() -> dict[str, dict[str, str]]:
|
||||
return get_config()["homeservers"]
|
||||
|
||||
|
||||
|
@ -61,7 +63,7 @@ truthy_strings = ("1", "true", "yes")
|
|||
|
||||
async def read_client_auth_request(
|
||||
request: web.Request,
|
||||
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
|
||||
) -> tuple[AuthRequestInfo | None, web.Response | None]:
|
||||
server_name = request.match_info.get("server", None)
|
||||
server = known_homeservers().get(server_name, None)
|
||||
if not server:
|
||||
|
@ -85,7 +87,7 @@ async def read_client_auth_request(
|
|||
return (
|
||||
AuthRequestInfo(
|
||||
server_name=server_name,
|
||||
client=ClientAPI(base_url=base_url, loop=get_loop()),
|
||||
client=ClientAPI(base_url=base_url),
|
||||
secret=server.get("secret"),
|
||||
username=username,
|
||||
password=password,
|
||||
|
@ -189,11 +191,11 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
|
|||
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
|
||||
{"redirectUrl": str(public_url)}
|
||||
)
|
||||
sso_waiters[waiter_id] = req, get_loop().create_future()
|
||||
sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
|
||||
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
|
||||
|
||||
|
||||
async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
|
||||
async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
|
||||
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
|
||||
device_id = f"maubot_{device_id}"
|
||||
try:
|
||||
|
@ -235,7 +237,7 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
|
|||
return web.json_response(res.serialize())
|
||||
|
||||
|
||||
sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
|
||||
sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
|
||||
|
||||
|
||||
@routes.post("/client/auth/{server}/sso/{id}/wait")
|
||||
|
|
|
@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
|
|||
@routes.view("/proxy/{id}/{path:_matrix/.+}")
|
||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||
user_id = request.match_info.get("id", None)
|
||||
client = Client.get(user_id, None)
|
||||
client = await Client.get(user_id)
|
||||
if not client:
|
||||
return resp.client_not_found
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@ from json import JSONDecodeError
|
|||
from aiohttp import web
|
||||
|
||||
from ...client import Client
|
||||
from ...db import DBPlugin
|
||||
from ...instance import PluginInstance
|
||||
from ...loader import PluginLoader
|
||||
from .base import routes
|
||||
|
@ -32,56 +31,49 @@ async def get_instances(_: web.Request) -> web.Response:
|
|||
|
||||
@routes.get("/instance/{id}")
|
||||
async def get_instance(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "").lower()
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
return resp.found(instance.to_dict())
|
||||
|
||||
|
||||
async def _create_instance(instance_id: str, data: dict) -> web.Response:
|
||||
plugin_type = data.get("type", None)
|
||||
primary_user = data.get("primary_user", None)
|
||||
plugin_type = data.get("type")
|
||||
primary_user = data.get("primary_user")
|
||||
if not plugin_type:
|
||||
return resp.plugin_type_required
|
||||
elif not primary_user:
|
||||
return resp.primary_user_required
|
||||
elif not Client.get(primary_user):
|
||||
elif not await Client.get(primary_user):
|
||||
return resp.primary_user_not_found
|
||||
try:
|
||||
PluginLoader.find(plugin_type)
|
||||
except KeyError:
|
||||
return resp.plugin_type_not_found
|
||||
db_instance = DBPlugin(
|
||||
id=instance_id,
|
||||
type=plugin_type,
|
||||
enabled=data.get("enabled", True),
|
||||
primary_user=primary_user,
|
||||
config=data.get("config", ""),
|
||||
)
|
||||
instance = PluginInstance(db_instance)
|
||||
instance.load()
|
||||
instance.db_instance.insert()
|
||||
instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user)
|
||||
instance.enabled = data.get("enabled", True)
|
||||
instance.config_str = data.get("config") or ""
|
||||
await instance.update()
|
||||
await instance.start()
|
||||
return resp.created(instance.to_dict())
|
||||
|
||||
|
||||
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
|
||||
if not await instance.update_primary_user(data.get("primary_user", None)):
|
||||
if not await instance.update_primary_user(data.get("primary_user")):
|
||||
return resp.primary_user_not_found
|
||||
with instance.db_instance.edit_mode():
|
||||
instance.update_id(data.get("id", None))
|
||||
instance.update_enabled(data.get("enabled", None))
|
||||
instance.update_config(data.get("config", None))
|
||||
await instance.update_started(data.get("started", None))
|
||||
await instance.update_type(data.get("type", None))
|
||||
return resp.updated(instance.to_dict())
|
||||
await instance.update_id(data.get("id"))
|
||||
await instance.update_enabled(data.get("enabled"))
|
||||
await instance.update_config(data.get("config"))
|
||||
await instance.update_started(data.get("started"))
|
||||
await instance.update_type(data.get("type"))
|
||||
return resp.updated(instance.to_dict())
|
||||
|
||||
|
||||
@routes.put("/instance/{id}")
|
||||
async def update_instance(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "").lower()
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
try:
|
||||
data = await request.json()
|
||||
except JSONDecodeError:
|
||||
|
@ -94,11 +86,11 @@ async def update_instance(request: web.Request) -> web.Response:
|
|||
|
||||
@routes.delete("/instance/{id}")
|
||||
async def delete_instance(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "").lower()
|
||||
instance = PluginInstance.get(instance_id)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
if instance.started:
|
||||
await instance.stop()
|
||||
instance.delete()
|
||||
await instance.delete()
|
||||
return resp.deleted
|
||||
|
|
|
@ -29,8 +29,8 @@ from .responses import resp
|
|||
|
||||
@routes.get("/instance/{id}/database")
|
||||
async def get_database(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "")
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
|
@ -65,8 +65,8 @@ def check_type(val):
|
|||
|
||||
@routes.get("/instance/{id}/database/{table}")
|
||||
async def get_table(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "")
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
|
@ -86,14 +86,14 @@ async def get_table(request: web.Request) -> web.Response:
|
|||
]
|
||||
except KeyError:
|
||||
order = []
|
||||
limit = int(request.query.get("limit", 100))
|
||||
limit = int(request.query.get("limit", "100"))
|
||||
return execute_query(instance, table.select().order_by(*order).limit(limit))
|
||||
|
||||
|
||||
@routes.post("/instance/{id}/database/query")
|
||||
async def query(request: web.Request) -> web.Response:
|
||||
instance_id = request.match_info.get("id", "")
|
||||
instance = PluginInstance.get(instance_id, None)
|
||||
instance_id = request.match_info["id"].lower()
|
||||
instance = await PluginInstance.get(instance_id)
|
||||
if not instance:
|
||||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
|
|
|
@ -23,7 +23,7 @@ import logging
|
|||
from aiohttp import web, web_ws
|
||||
|
||||
from .auth import is_valid_token
|
||||
from .base import get_loop, routes
|
||||
from .base import routes
|
||||
|
||||
BUILTIN_ATTRS = {
|
||||
"args",
|
||||
|
@ -138,12 +138,12 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
|
|||
authenticated = False
|
||||
|
||||
async def close_if_not_authenticated():
|
||||
await asyncio.sleep(5, loop=get_loop())
|
||||
await asyncio.sleep(5)
|
||||
if not authenticated:
|
||||
await ws.close(code=4000)
|
||||
log.debug(f"Connection from {request.remote} terminated due to no authentication")
|
||||
|
||||
asyncio.ensure_future(close_if_not_authenticated())
|
||||
asyncio.create_task(close_if_not_authenticated())
|
||||
|
||||
try:
|
||||
msg: web_ws.WSMessage
|
||||
|
|
|
@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
|
|||
|
||||
@routes.get("/plugin/{id}")
|
||||
async def get_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
plugin_id = request.match_info["id"]
|
||||
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||
if not plugin:
|
||||
return resp.plugin_not_found
|
||||
return resp.found(plugin.to_dict())
|
||||
|
@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
|
|||
|
||||
@routes.delete("/plugin/{id}")
|
||||
async def delete_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
plugin_id = request.match_info["id"]
|
||||
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||
if not plugin:
|
||||
return resp.plugin_not_found
|
||||
elif len(plugin.references) > 0:
|
||||
|
@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
|
|||
|
||||
@routes.post("/plugin/{id}/reload")
|
||||
async def reload_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin = PluginLoader.id_cache.get(plugin_id, None)
|
||||
plugin_id = request.match_info["id"]
|
||||
plugin = PluginLoader.id_cache.get(plugin_id)
|
||||
if not plugin:
|
||||
return resp.plugin_not_found
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ from .responses import resp
|
|||
|
||||
@routes.put("/plugin/{id}")
|
||||
async def put_plugin(request: web.Request) -> web.Response:
|
||||
plugin_id = request.match_info.get("id", None)
|
||||
plugin_id = request.match_info["id"]
|
||||
content = await request.read()
|
||||
file = BytesIO(content)
|
||||
try:
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Awaitable
|
||||
from abc import ABC
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
|
@ -124,6 +124,7 @@ class Plugin(ABC):
|
|||
def get_config_class(cls) -> type[BaseProxyConfig] | None:
|
||||
return None
|
||||
|
||||
def on_external_config_update(self) -> None:
|
||||
def on_external_config_update(self) -> Awaitable[None] | None:
|
||||
if self.config:
|
||||
self.config.load_and_update()
|
||||
return None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue