mirror of
https://github.com/maubot/maubot
synced 2025-09-04 19:30:38 +00:00
parent
068e268c63
commit
21ed971d2f
43 changed files with 911 additions and 955 deletions
|
@ -15,8 +15,10 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
from asyncio import AbstractEventLoop
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
import inspect
|
||||
import io
|
||||
import logging
|
||||
import os.path
|
||||
|
@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
|
|||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||
|
||||
from .client import Client
|
||||
from .config import Config
|
||||
from .db import DBPlugin
|
||||
from .db import Instance as DBInstance
|
||||
from .loader import PluginLoader, ZippedPluginLoader
|
||||
from .plugin_base import Plugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .server import MaubotServer, PluginWebApp
|
||||
from .__main__ import Maubot
|
||||
from .server import PluginWebApp
|
||||
|
||||
log = logging.getLogger("maubot.instance")
|
||||
|
||||
|
@ -44,29 +47,42 @@ yaml.indent(4)
|
|||
yaml.width = 200
|
||||
|
||||
|
||||
class PluginInstance:
|
||||
webserver: MaubotServer = None
|
||||
mb_config: Config = None
|
||||
loop: AbstractEventLoop = None
|
||||
class PluginInstance(DBInstance):
|
||||
maubot: "Maubot" = None
|
||||
cache: dict[str, PluginInstance] = {}
|
||||
plugin_directories: list[str] = []
|
||||
_async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
|
||||
|
||||
log: logging.Logger
|
||||
loader: PluginLoader
|
||||
client: Client
|
||||
plugin: Plugin
|
||||
config: BaseProxyConfig
|
||||
loader: PluginLoader | None
|
||||
client: Client | None
|
||||
plugin: Plugin | None
|
||||
config: BaseProxyConfig | None
|
||||
base_cfg: RecursiveDict[CommentedMap] | None
|
||||
base_cfg_str: str | None
|
||||
inst_db: sql.engine.Engine
|
||||
inst_db_tables: dict[str, sql.Table]
|
||||
inst_db: sql.engine.Engine | None
|
||||
inst_db_tables: dict[str, sql.Table] | None
|
||||
inst_webapp: PluginWebApp | None
|
||||
inst_webapp_url: str | None
|
||||
started: bool
|
||||
|
||||
def __init__(self, db_instance: DBPlugin):
|
||||
self.db_instance = db_instance
|
||||
def __init__(
|
||||
self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
@classmethod
|
||||
def init_cls(cls, maubot: "Maubot") -> None:
|
||||
cls.maubot = maubot
|
||||
|
||||
def postinit(self) -> None:
|
||||
self.log = log.getChild(self.id)
|
||||
self.cache[self.id] = self
|
||||
self.config = None
|
||||
self.started = False
|
||||
self.loader = None
|
||||
|
@ -78,7 +94,6 @@ class PluginInstance:
|
|||
self.inst_webapp_url = None
|
||||
self.base_cfg = None
|
||||
self.base_cfg_str = None
|
||||
self.cache[self.id] = self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
@ -87,10 +102,10 @@ class PluginInstance:
|
|||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"primary_user": self.primary_user,
|
||||
"config": self.db_instance.config,
|
||||
"config": self.config_str,
|
||||
"base_config": self.base_cfg_str,
|
||||
"database": (
|
||||
self.inst_db is not None and self.mb_config["api_features.instance_database"]
|
||||
self.inst_db is not None and self.maubot.config["api_features.instance_database"]
|
||||
),
|
||||
}
|
||||
|
||||
|
@ -101,19 +116,19 @@ class PluginInstance:
|
|||
self.inst_db_tables = metadata.tables
|
||||
return self.inst_db_tables
|
||||
|
||||
def load(self) -> bool:
|
||||
async def load(self) -> bool:
|
||||
if not self.loader:
|
||||
try:
|
||||
self.loader = PluginLoader.find(self.type)
|
||||
except KeyError:
|
||||
self.log.error(f"Failed to find loader for type {self.type}")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return False
|
||||
if not self.client:
|
||||
self.client = Client.get(self.primary_user)
|
||||
self.client = await Client.get(self.primary_user)
|
||||
if not self.client:
|
||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return False
|
||||
if self.loader.meta.database:
|
||||
self.enable_database()
|
||||
|
@ -125,18 +140,18 @@ class PluginInstance:
|
|||
return True
|
||||
|
||||
def enable_webapp(self) -> None:
|
||||
self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
|
||||
self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
|
||||
|
||||
def disable_webapp(self) -> None:
|
||||
self.webserver.remove_instance_webapp(self.id)
|
||||
self.maubot.server.remove_instance_webapp(self.id)
|
||||
self.inst_webapp = None
|
||||
self.inst_webapp_url = None
|
||||
|
||||
def enable_database(self) -> None:
|
||||
db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
|
||||
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
|
||||
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
|
||||
|
||||
def delete(self) -> None:
|
||||
async def delete(self) -> None:
|
||||
if self.loader is not None:
|
||||
self.loader.references.remove(self)
|
||||
if self.client is not None:
|
||||
|
@ -145,23 +160,23 @@ class PluginInstance:
|
|||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db_instance.delete()
|
||||
await super().delete()
|
||||
if self.inst_db:
|
||||
self.inst_db.dispose()
|
||||
ZippedPluginLoader.trash(
|
||||
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
|
||||
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
|
||||
reason="deleted",
|
||||
)
|
||||
if self.inst_webapp:
|
||||
self.disable_webapp()
|
||||
|
||||
def load_config(self) -> CommentedMap:
|
||||
return yaml.load(self.db_instance.config)
|
||||
return yaml.load(self.config_str)
|
||||
|
||||
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
||||
buf = io.StringIO()
|
||||
yaml.dump(data, buf)
|
||||
self.db_instance.config = buf.getvalue()
|
||||
self.config_str = buf.getvalue()
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.started:
|
||||
|
@ -172,7 +187,7 @@ class PluginInstance:
|
|||
return
|
||||
if not self.client or not self.loader:
|
||||
self.log.warning("Missing plugin instance dependencies, attempting to load...")
|
||||
if not self.load():
|
||||
if not await self.load():
|
||||
return
|
||||
cls = await self.loader.load()
|
||||
if self.loader.meta.webapp and self.inst_webapp is None:
|
||||
|
@ -205,7 +220,7 @@ class PluginInstance:
|
|||
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
|
||||
self.plugin = cls(
|
||||
client=self.client.client,
|
||||
loop=self.loop,
|
||||
loop=self.maubot.loop,
|
||||
http=self.client.http_client,
|
||||
instance_id=self.id,
|
||||
log=self.log,
|
||||
|
@ -219,7 +234,7 @@ class PluginInstance:
|
|||
await self.plugin.internal_start()
|
||||
except Exception:
|
||||
self.log.exception("Failed to start instance")
|
||||
self.db_instance.enabled = False
|
||||
await self.update_enabled(False)
|
||||
return
|
||||
self.started = True
|
||||
self.inst_db_tables = None
|
||||
|
@ -241,60 +256,51 @@ class PluginInstance:
|
|||
self.plugin = None
|
||||
self.inst_db_tables = None
|
||||
|
||||
@classmethod
|
||||
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBPlugin.get(instance_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return PluginInstance(db_instance)
|
||||
async def update_id(self, new_id: str | None) -> None:
|
||||
if new_id is not None and new_id.lower() != self.id:
|
||||
await super().update_id(new_id.lower())
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable[PluginInstance]:
|
||||
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
|
||||
|
||||
def update_id(self, new_id: str) -> None:
|
||||
if new_id is not None and new_id != self.id:
|
||||
self.db_instance.id = new_id.lower()
|
||||
|
||||
def update_config(self, config: str) -> None:
|
||||
if not config or self.db_instance.config == config:
|
||||
async def update_config(self, config: str | None) -> None:
|
||||
if config is None or self.config_str == config:
|
||||
return
|
||||
self.db_instance.config = config
|
||||
self.config_str = config
|
||||
if self.started and self.plugin is not None:
|
||||
self.plugin.on_external_config_update()
|
||||
res = self.plugin.on_external_config_update()
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
await self.update()
|
||||
|
||||
async def update_primary_user(self, primary_user: UserID) -> bool:
|
||||
if not primary_user or primary_user == self.primary_user:
|
||||
async def update_primary_user(self, primary_user: UserID | None) -> bool:
|
||||
if primary_user is None or primary_user == self.primary_user:
|
||||
return True
|
||||
client = Client.get(primary_user)
|
||||
client = await Client.get(primary_user)
|
||||
if not client:
|
||||
return False
|
||||
await self.stop()
|
||||
self.db_instance.primary_user = client.id
|
||||
self.primary_user = client.id
|
||||
if self.client:
|
||||
self.client.references.remove(self)
|
||||
self.client = client
|
||||
self.client.references.add(self)
|
||||
await self.update()
|
||||
await self.start()
|
||||
self.log.debug(f"Primary user switched to {self.client.id}")
|
||||
return True
|
||||
|
||||
async def update_type(self, type: str) -> bool:
|
||||
if not type or type == self.type:
|
||||
async def update_type(self, type: str | None) -> bool:
|
||||
if type is None or type == self.type:
|
||||
return True
|
||||
try:
|
||||
loader = PluginLoader.find(type)
|
||||
except KeyError:
|
||||
return False
|
||||
await self.stop()
|
||||
self.db_instance.type = loader.meta.id
|
||||
self.type = loader.meta.id
|
||||
if self.loader:
|
||||
self.loader.references.remove(self)
|
||||
self.loader = loader
|
||||
self.loader.references.add(self)
|
||||
await self.update()
|
||||
await self.start()
|
||||
self.log.debug(f"Type switched to {self.loader.meta.id}")
|
||||
return True
|
||||
|
@ -303,39 +309,41 @@ class PluginInstance:
|
|||
if started is not None and started != self.started:
|
||||
await (self.start() if started else self.stop())
|
||||
|
||||
def update_enabled(self, enabled: bool) -> None:
|
||||
async def update_enabled(self, enabled: bool) -> None:
|
||||
if enabled is not None and enabled != self.enabled:
|
||||
self.db_instance.enabled = enabled
|
||||
self.enabled = enabled
|
||||
await self.update()
|
||||
|
||||
# region Properties
|
||||
@classmethod
|
||||
@async_getter_lock
|
||||
async def get(
|
||||
cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
|
||||
) -> PluginInstance | None:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.db_instance.id
|
||||
instance = cast(cls, await super().get(instance_id))
|
||||
if instance is not None:
|
||||
instance.postinit()
|
||||
return instance
|
||||
|
||||
@id.setter
|
||||
def id(self, value: str) -> None:
|
||||
self.db_instance.id = value
|
||||
if type and primary_user:
|
||||
instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
|
||||
await instance.insert()
|
||||
instance.postinit()
|
||||
return instance
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.db_instance.type
|
||||
return None
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@property
|
||||
def primary_user(self) -> UserID:
|
||||
return self.db_instance.primary_user
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(
|
||||
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
|
||||
) -> Iterable[PluginInstance]:
|
||||
PluginInstance.mb_config = config
|
||||
PluginInstance.loop = loop
|
||||
PluginInstance.webserver = webserver
|
||||
return PluginInstance.all()
|
||||
@classmethod
|
||||
async def all(cls) -> AsyncGenerator[PluginInstance, None]:
|
||||
instances = await super().all()
|
||||
instance: PluginInstance
|
||||
for instance in instances:
|
||||
try:
|
||||
yield cls.cache[instance.id]
|
||||
except KeyError:
|
||||
instance.postinit()
|
||||
yield instance
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue