Switch to asyncpg/aiosqlite

Fixes #142
Fixes #98
Probably fixes #62
This commit is contained in:
Tulir Asokan 2022-03-25 19:45:48 +02:00
parent 068e268c63
commit 21ed971d2f
43 changed files with 911 additions and 955 deletions

View file

@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable
from asyncio import AbstractEventLoop
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
from collections import defaultdict
import asyncio
import inspect
import io
import logging
import os.path
@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from mautrix.types import UserID
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from .client import Client
from .config import Config
from .db import DBPlugin
from .db import Instance as DBInstance
from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
from .server import MaubotServer, PluginWebApp
from .__main__ import Maubot
from .server import PluginWebApp
log = logging.getLogger("maubot.instance")
@ -44,29 +47,42 @@ yaml.indent(4)
yaml.width = 200
class PluginInstance:
webserver: MaubotServer = None
mb_config: Config = None
loop: AbstractEventLoop = None
class PluginInstance(DBInstance):
maubot: "Maubot" = None
cache: dict[str, PluginInstance] = {}
plugin_directories: list[str] = []
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: logging.Logger
loader: PluginLoader
client: Client
plugin: Plugin
config: BaseProxyConfig
loader: PluginLoader | None
client: Client | None
plugin: Plugin | None
config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
inst_db: sql.engine.Engine
inst_db_tables: dict[str, sql.Table]
inst_db: sql.engine.Engine | None
inst_db_tables: dict[str, sql.Table] | None
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
started: bool
def __init__(self, db_instance: DBPlugin):
self.db_instance = db_instance
def __init__(
self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
) -> None:
super().__init__(
id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
)
def __hash__(self) -> int:
return hash(self.id)
@classmethod
def init_cls(cls, maubot: "Maubot") -> None:
cls.maubot = maubot
def postinit(self) -> None:
self.log = log.getChild(self.id)
self.cache[self.id] = self
self.config = None
self.started = False
self.loader = None
@ -78,7 +94,6 @@ class PluginInstance:
self.inst_webapp_url = None
self.base_cfg = None
self.base_cfg_str = None
self.cache[self.id] = self
def to_dict(self) -> dict:
return {
@ -87,10 +102,10 @@ class PluginInstance:
"enabled": self.enabled,
"started": self.started,
"primary_user": self.primary_user,
"config": self.db_instance.config,
"config": self.config_str,
"base_config": self.base_cfg_str,
"database": (
self.inst_db is not None and self.mb_config["api_features.instance_database"]
self.inst_db is not None and self.maubot.config["api_features.instance_database"]
),
}
@ -101,19 +116,19 @@ class PluginInstance:
self.inst_db_tables = metadata.tables
return self.inst_db_tables
def load(self) -> bool:
async def load(self) -> bool:
if not self.loader:
try:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
self.db_instance.enabled = False
await self.update_enabled(False)
return False
if not self.client:
self.client = Client.get(self.primary_user)
self.client = await Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
self.db_instance.enabled = False
await self.update_enabled(False)
return False
if self.loader.meta.database:
self.enable_database()
@ -125,18 +140,18 @@ class PluginInstance:
return True
def enable_webapp(self) -> None:
self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
def disable_webapp(self) -> None:
self.webserver.remove_instance_webapp(self.id)
self.maubot.server.remove_instance_webapp(self.id)
self.inst_webapp = None
self.inst_webapp_url = None
def enable_database(self) -> None:
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
def delete(self) -> None:
async def delete(self) -> None:
if self.loader is not None:
self.loader.references.remove(self)
if self.client is not None:
@ -145,23 +160,23 @@ class PluginInstance:
del self.cache[self.id]
except KeyError:
pass
self.db_instance.delete()
await super().delete()
if self.inst_db:
self.inst_db.dispose()
ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted",
)
if self.inst_webapp:
self.disable_webapp()
def load_config(self) -> CommentedMap:
return yaml.load(self.db_instance.config)
return yaml.load(self.config_str)
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO()
yaml.dump(data, buf)
self.db_instance.config = buf.getvalue()
self.config_str = buf.getvalue()
async def start(self) -> None:
if self.started:
@ -172,7 +187,7 @@ class PluginInstance:
return
if not self.client or not self.loader:
self.log.warning("Missing plugin instance dependencies, attempting to load...")
if not self.load():
if not await self.load():
return
cls = await self.loader.load()
if self.loader.meta.webapp and self.inst_webapp is None:
@ -205,7 +220,7 @@ class PluginInstance:
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls(
client=self.client.client,
loop=self.loop,
loop=self.maubot.loop,
http=self.client.http_client,
instance_id=self.id,
log=self.log,
@ -219,7 +234,7 @@ class PluginInstance:
await self.plugin.internal_start()
except Exception:
self.log.exception("Failed to start instance")
self.db_instance.enabled = False
await self.update_enabled(False)
return
self.started = True
self.inst_db_tables = None
@ -241,60 +256,51 @@ class PluginInstance:
self.plugin = None
self.inst_db_tables = None
@classmethod
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
db_instance = db_instance or DBPlugin.get(instance_id)
if not db_instance:
return None
return PluginInstance(db_instance)
async def update_id(self, new_id: str | None) -> None:
if new_id is not None and new_id.lower() != self.id:
await super().update_id(new_id.lower())
@classmethod
def all(cls) -> Iterable[PluginInstance]:
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None:
if new_id is not None and new_id != self.id:
self.db_instance.id = new_id.lower()
def update_config(self, config: str) -> None:
if not config or self.db_instance.config == config:
async def update_config(self, config: str | None) -> None:
if config is None or self.config_str == config:
return
self.db_instance.config = config
self.config_str = config
if self.started and self.plugin is not None:
self.plugin.on_external_config_update()
res = self.plugin.on_external_config_update()
if inspect.isawaitable(res):
await res
await self.update()
async def update_primary_user(self, primary_user: UserID) -> bool:
if not primary_user or primary_user == self.primary_user:
async def update_primary_user(self, primary_user: UserID | None) -> bool:
if primary_user is None or primary_user == self.primary_user:
return True
client = Client.get(primary_user)
client = await Client.get(primary_user)
if not client:
return False
await self.stop()
self.db_instance.primary_user = client.id
self.primary_user = client.id
if self.client:
self.client.references.remove(self)
self.client = client
self.client.references.add(self)
await self.update()
await self.start()
self.log.debug(f"Primary user switched to {self.client.id}")
return True
async def update_type(self, type: str) -> bool:
if not type or type == self.type:
async def update_type(self, type: str | None) -> bool:
if type is None or type == self.type:
return True
try:
loader = PluginLoader.find(type)
except KeyError:
return False
await self.stop()
self.db_instance.type = loader.meta.id
self.type = loader.meta.id
if self.loader:
self.loader.references.remove(self)
self.loader = loader
self.loader.references.add(self)
await self.update()
await self.start()
self.log.debug(f"Type switched to {self.loader.meta.id}")
return True
@ -303,39 +309,41 @@ class PluginInstance:
if started is not None and started != self.started:
await (self.start() if started else self.stop())
def update_enabled(self, enabled: bool) -> None:
async def update_enabled(self, enabled: bool) -> None:
if enabled is not None and enabled != self.enabled:
self.db_instance.enabled = enabled
self.enabled = enabled
await self.update()
# region Properties
@classmethod
@async_getter_lock
async def get(
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
pass
@property
def id(self) -> str:
return self.db_instance.id
instance = cast(cls, await super().get(instance_id))
if instance is not None:
instance.postinit()
return instance
@id.setter
def id(self, value: str) -> None:
self.db_instance.id = value
if type and primary_user:
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
await instance.insert()
instance.postinit()
return instance
@property
def type(self) -> str:
return self.db_instance.type
return None
@property
def enabled(self) -> bool:
return self.db_instance.enabled
@property
def primary_user(self) -> UserID:
return self.db_instance.primary_user
# endregion
def init(
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
) -> Iterable[PluginInstance]:
PluginInstance.mb_config = config
PluginInstance.loop = loop
PluginInstance.webserver = webserver
return PluginInstance.all()
@classmethod
async def all(cls) -> AsyncGenerator[PluginInstance, None]:
instances = await super().all()
instance: PluginInstance
for instance in instances:
try:
yield cls.cache[instance.id]
except KeyError:
instance.postinit()
yield instance