Refactor __main__.py and fix things

This commit is contained in:
Tulir Asokan 2021-11-19 15:22:54 +02:00
parent c685eb5e08
commit 7c9668d8bc
9 changed files with 84 additions and 117 deletions

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2021 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging.config
import argparse
import asyncio
import signal
import copy
import sys
from mautrix.util.program import Program
from .config import Config
from .db import init as init_db
@ -27,70 +24,58 @@ from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api
from .lib.future_awaitable import FutureAwaitable
from .__meta__ import __version__
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
prog="python -m maubot")
parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
args = parser.parse_args()
config = Config(args.config, args.base_config)
config.load()
config.update()
class Maubot(Program):
config: Config
server: MaubotServer
logging.config.dictConfig(copy.deepcopy(config["logging"]))
config_class = Config
module = "maubot"
name = "maubot"
version = __version__
command = "python -m maubot"
description = "A plugin-based Matrix bot system."
loop = asyncio.get_event_loop()
def prepare_log_websocket(self) -> None:
from .management.api.log import init, stop_all
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))
stop_log_listener = None
if config["api_features.log"]:
from .management.api.log import init as init_log_listener, stop_all as stop_log_listener
def prepare(self) -> None:
super().prepare()
init_log_listener(loop)
if self.config["api_features.log"]:
self.prepare_log_websocket()
log = logging.getLogger("maubot.init")
log.info(f"Initializing maubot {__version__}")
init_zip_loader(self.config)
init_db(self.config)
clients = init_client_class(self.config, self.loop)
self.add_startup_actions(*(client.start() for client in clients))
management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop)
init_zip_loader(config)
db_engine = init_db(config)
clients = init_client_class(config, loop)
management_api = init_mgmt_api(config, loop)
server = MaubotServer(management_api, config, loop)
plugins = init_plugin_instance_class(config, server, loop)
plugins = init_plugin_instance_class(self.config, self.server, self.loop)
for plugin in plugins:
plugin.load()
for plugin in plugins:
plugin.load()
async def start(self) -> None:
if Client.crypto_db:
self.log.debug("Starting client crypto database")
await Client.crypto_db.start()
await super().start()
await self.server.start()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
async def stop(self) -> None:
self.add_shutdown_actions(*(client.stop() for client in Client.cache.values()))
await super().stop()
self.log.debug("Stopping server")
try:
await asyncio.wait_for(self.server.stop(), 5)
except asyncio.TimeoutError:
self.log.warning("Stopping server timed out")
try:
log.info("Starting server")
loop.run_until_complete(server.start())
if Client.crypto_db:
log.debug("Starting client crypto database")
loop.run_until_complete(Client.crypto_db.start())
log.info("Starting clients and plugins")
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
log.info("Startup actions complete, running forever")
loop.run_forever()
except KeyboardInterrupt:
log.info("Interrupt received, stopping clients")
loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()]))
if stop_log_listener is not None:
log.debug("Closing websockets")
loop.run_until_complete(stop_log_listener())
log.debug("Stopping server")
try:
loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop))
except asyncio.TimeoutError:
log.warning("Stopping server timed out")
log.debug("Closing event loop")
loop.close()
log.debug("Everything stopped, shutting down")
sys.exit(0)
Maubot().run()

View file

@ -14,17 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from os import path
import asyncio
import logging
from aiohttp import ClientSession
from yarl import URL
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter)
PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
@ -33,13 +31,12 @@ from .db import DBClient
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
PickleCryptoStore)
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, CryptoStore
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
except ImportError:
except ImportError as e:
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
SQLStateStore = BaseSQLStateStore
@ -63,8 +60,7 @@ class Client:
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_pickle_dir: str = None
crypto_db: 'AsyncDatabase' = None
crypto_db: Optional['AsyncDatabase'] = None
references: Set['PluginInstance']
db_instance: DBClient
@ -90,7 +86,7 @@ class Client:
log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
if OlmMachine and self.device_id and self.crypto_db:
self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
@ -109,9 +105,6 @@ class Client:
def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
elif self.crypto_pickle_dir:
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
raise ValueError("Crypto database not configured")
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
@ -330,7 +323,7 @@ class Client:
return self.db_instance.access_token
@property
def device_id(self) -> str:
def device_id(self) -> DeviceID:
return self.db_instance.device_id
@property
@ -403,25 +396,9 @@ def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.loop = loop
if OlmMachine:
db_type = config["crypto_database.type"]
if db_type == "default":
db_url = config["crypto_database"]
if db_url == "default":
db_url = config["database"]
parsed_url = URL(db_url)
if parsed_url.scheme == "sqlite":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif parsed_url.scheme == "postgres" or parsed_url.scheme == "postgresql":
if not PgCryptoStore:
log.warning("Default database is postgres, but asyncpg is not installed. "
"Encryption will not work.")
else:
Client.crypto_db = AsyncDatabase(url=db_url,
upgrade_table=PgCryptoStore.upgrade_table)
elif db_type == "pickle":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif (db_type == "postgres" or db_type == "postgresql") and PgCryptoStore:
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
upgrade_table=PgCryptoStore.upgrade_table)
else:
raise ValueError("Unsupported crypto database type")
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
return Client.all()

View file

@ -32,9 +32,11 @@ class Config(BaseFileConfig):
base = helper.base
copy = helper.copy
copy("database")
copy("crypto_database.type")
copy("crypto_database.postgres_uri")
copy("crypto_database.pickle_dir")
if isinstance(self["crypto_database"], dict):
if self["crypto_database.type"] == "postgres":
base["crypto_database"] = self["crypto_database.postgres_uri"]
else:
copy("crypto_database")
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")

102
maubot/example-config.yaml Normal file
View file

@ -0,0 +1,102 @@
# The full URI to the database. SQLite and Postgres are fully supported.
# Other DBMSes supported by SQLAlchemy may or may not work.
# Format examples:
# SQLite: sqlite:///filename.db
# Postgres: postgresql://username:password@hostname/dbname
database: sqlite:///maubot.db
# Separate database URL for the crypto database. "default" means use the same database as above.
crypto_database: default
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: ./plugins
# The directories from which plugins should be loaded.
# Duplicate plugin IDs will be moved to the trash.
load:
- ./plugins
# The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately.
trash: ./trash
# The directory where plugin databases should be stored.
db: ./plugins
server:
# The IP and port to listen to.
hostname: 0.0.0.0
port: 29316
# Public base URL where the server is visible.
public_url: https://example.com
# The base management API path.
base_path: /_matrix/maubot/v1
# The base path for the UI.
ui_base_path: /_matrix/maubot
# The base path for plugin endpoints. The instance ID will be appended directly.
plugin_base_path: /_matrix/maubot/plugin/
# Override path from where to load UI resources.
# Set to false to using pkg_resources to find the path.
override_resource_path: false
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
appservice_base_path: /_matrix/app/v1
# The shared secret to sign API access tokens.
# Set to "generate" to generate and save a new token at startup.
unshared_secret: generate
# Shared registration secrets to allow registering new users from the management UI
registration_secrets:
example.com:
# Client-server API URL
url: https://example.com
# registration_shared_secret from synapse config
secret: synapse_shared_registration_secret
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist.
admins:
root: ""
# API feature switches.
api_features:
login: true
plugin: true
plugin_upload: true
instance: true
instance_database: true
client: true
client_proxy: true
client_auth: true
dev_open: true
log: true
# Python logging configuration.
#
# See section 16.7.2 of the Python documentation for more info:
# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema
logging:
version: 1
formatters:
colored:
(): maubot.lib.color_log.ColorFormatter
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
normal:
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: normal
filename: ./maubot.log
maxBytes: 10485760
backupCount: 10
console:
class: logging.StreamHandler
formatter: colored
loggers:
maubot:
level: DEBUG
mau:
level: DEBUG
aiohttp:
level: INFO
root:
level: DEBUG
handlers: [file, console]

View file

@ -0,0 +1,9 @@
from typing import Callable, Awaitable, Generator, Any
class FutureAwaitable:
def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
self._func = func
def __await__(self) -> Generator[Any, None, None]:
return self._func().__await__()

View file

@ -93,6 +93,7 @@ def init(loop: asyncio.AbstractEventLoop) -> None:
async def stop_all() -> None:
log.debug("Closing log listener websockets")
log_root.removeHandler(handler)
for socket in sockets:
try: