mirror of
https://github.com/maubot/maubot
synced 2025-08-29 17:50:38 +00:00
Add support for end-to-end encryption. Fixes #46
This commit is contained in:
parent
4e767a10e4
commit
69d7a4341b
17 changed files with 203 additions and 24 deletions
|
@ -57,7 +57,7 @@ log.info(f"Initializing maubot {__version__}")
|
|||
|
||||
init_zip_loader(config)
|
||||
db_engine = init_db(config)
|
||||
clients = init_client_class(loop)
|
||||
clients = init_client_class(config, loop)
|
||||
management_api = init_mgmt_api(config, loop)
|
||||
server = MaubotServer(management_api, config, loop)
|
||||
plugins = init_plugin_instance_class(config, server, loop)
|
||||
|
@ -72,6 +72,9 @@ signal.signal(signal.SIGTERM, signal.default_int_handler)
|
|||
try:
|
||||
log.info("Starting server")
|
||||
loop.run_until_complete(server.start())
|
||||
if Client.crypto_db:
|
||||
log.debug("Starting client crypto database")
|
||||
loop.run_until_complete(Client.crypto_db.start())
|
||||
log.info("Starting clients and plugins")
|
||||
loop.run_until_complete(asyncio.gather(*[client.start() for client in clients]))
|
||||
log.info("Startup actions complete, running forever")
|
||||
|
|
|
@ -18,12 +18,13 @@ from io import BytesIO
|
|||
import zipfile
|
||||
import os
|
||||
|
||||
from mautrix.client.api.types.util import SerializerError
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
from colorama import Fore
|
||||
from PyInquirer import prompt
|
||||
import click
|
||||
|
||||
from mautrix.types import SerializerError
|
||||
|
||||
from ...loader import PluginMeta
|
||||
from ..cliq.validators import PathValidator
|
||||
from ..base import app
|
||||
|
|
|
@ -18,9 +18,10 @@ import asyncio
|
|||
|
||||
from colorama import Fore
|
||||
from aiohttp import WSMsgType, WSMessage, ClientSession
|
||||
from mautrix.client.api.types.util import Obj
|
||||
import click
|
||||
|
||||
from mautrix.types import Obj
|
||||
|
||||
from ..config import get_token
|
||||
from ..base import app
|
||||
|
||||
|
|
|
@ -13,24 +13,46 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
|
||||
from os import path
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
||||
PresenceState, StateFilter)
|
||||
from mautrix.client import InternalEventType
|
||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||
|
||||
from .lib.store_proxy import ClientStoreProxy
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .db import DBClient
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
|
||||
PickleCryptoStore)
|
||||
|
||||
|
||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||
pass
|
||||
except ImportError:
|
||||
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
|
||||
SQLStateStore = BaseSQLStateStore
|
||||
|
||||
try:
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
from mautrix.crypto import PgCryptoStore
|
||||
except ImportError:
|
||||
AsyncDatabase = None
|
||||
PgCryptoStore = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .instance import PluginInstance
|
||||
from .config import Config
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
@ -40,10 +62,15 @@ class Client:
|
|||
loop: asyncio.AbstractEventLoop = None
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
http_client: ClientSession = None
|
||||
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
|
||||
crypto_pickle_dir: str = None
|
||||
crypto_db: 'AsyncDatabase' = None
|
||||
|
||||
references: Set['PluginInstance']
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
crypto: Optional['OlmMachine']
|
||||
crypto_store: Optional['CryptoStore']
|
||||
started: bool
|
||||
|
||||
remote_displayname: Optional[str]
|
||||
|
@ -61,7 +88,15 @@ class Client:
|
|||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||
store=ClientStoreProxy(self.db_instance))
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store)
|
||||
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
|
||||
self.crypto_store = self._make_crypto_store()
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
else:
|
||||
self.crypto_store = None
|
||||
self.crypto = None
|
||||
self.client.ignore_initial_sync = True
|
||||
self.client.ignore_first_sync = True
|
||||
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
|
||||
|
@ -71,6 +106,14 @@ class Client:
|
|||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||
|
||||
def _make_crypto_store(self) -> 'CryptoStore':
|
||||
if self.crypto_db:
|
||||
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
||||
elif self.crypto_pickle_dir:
|
||||
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
|
||||
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
|
||||
raise ValueError("Crypto database not configured")
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: Dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
@ -130,6 +173,16 @@ class Client:
|
|||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
if self.crypto:
|
||||
self.log.debug("Enabling end-to-end encryption support")
|
||||
await self.crypto_store.open()
|
||||
crypto_device_id = await self.crypto_store.get_device_id()
|
||||
if crypto_device_id and crypto_device_id != self.device_id:
|
||||
self.log.warning("Mismatching device ID in crypto store and main database. "
|
||||
"Encryption may not work.")
|
||||
await self.crypto.load()
|
||||
if not crypto_device_id:
|
||||
await self.crypto_store.put_device_id(self.device_id)
|
||||
self.start_sync()
|
||||
await self._update_remote_profile()
|
||||
self.started = True
|
||||
|
@ -154,6 +207,8 @@ class Client:
|
|||
self.started = False
|
||||
await self.stop_plugins()
|
||||
self.stop_sync()
|
||||
if self.crypto:
|
||||
await self.crypto_store.close()
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
|
@ -172,6 +227,7 @@ class Client:
|
|||
"id": self.id,
|
||||
"homeserver": self.homeserver,
|
||||
"access_token": self.access_token,
|
||||
"device_id": self.device_id,
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"sync": self.sync,
|
||||
|
@ -243,11 +299,12 @@ class Client:
|
|||
return
|
||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token, loop=self.loop,
|
||||
client_session=self.http_client, log=self.log)
|
||||
client_session=self.http_client, device_id=self.device_id,
|
||||
log=self.log, state_store=self.global_state_store)
|
||||
mxid = await new_client.whoami()
|
||||
if mxid != self.id:
|
||||
raise ValueError(f"MXID mismatch: {mxid}")
|
||||
new_client.store = self.db_instance
|
||||
new_client.sync_store = self.db_instance
|
||||
self.stop_sync()
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
|
@ -341,7 +398,30 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
|
||||
if OlmMachine:
|
||||
db_type = config["crypto_database.type"]
|
||||
if db_type == "default":
|
||||
db_url = config["database"]
|
||||
parsed_url = URL(db_url)
|
||||
if parsed_url.scheme == "sqlite":
|
||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||
elif parsed_url.scheme == "postgres":
|
||||
if not PgCryptoStore:
|
||||
log.warning("Default database is postgres, but asyncpg is not installed. "
|
||||
"Encryption will not work.")
|
||||
else:
|
||||
Client.crypto_db = AsyncDatabase(url=db_url,
|
||||
upgrade_table=PgCryptoStore.upgrade_table)
|
||||
elif db_type == "pickle":
|
||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||
elif db_type == "postgres" and PgCryptoStore:
|
||||
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
|
||||
upgrade_table=PgCryptoStore.upgrade_table)
|
||||
else:
|
||||
raise ValueError("Unsupported crypto database type")
|
||||
|
||||
return Client.all()
|
||||
|
|
|
@ -32,6 +32,9 @@ class Config(BaseFileConfig):
|
|||
base = helper.base
|
||||
copy = helper.copy
|
||||
copy("database")
|
||||
copy("crypto_database.type")
|
||||
copy("crypto_database.postgres_uri")
|
||||
copy("crypto_database.pickle_dir")
|
||||
copy("plugin_directories.upload")
|
||||
copy("plugin_directories.load")
|
||||
copy("plugin_directories.trash")
|
||||
|
|
|
@ -23,6 +23,7 @@ import sqlalchemy as sql
|
|||
|
||||
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
|
||||
|
||||
from .config import Config
|
||||
|
||||
|
@ -79,7 +80,7 @@ def init(config: Config) -> Engine:
|
|||
db = sql.create_engine(config["database"])
|
||||
Base.metadata.bind = db
|
||||
|
||||
for table in (DBPlugin, DBClient):
|
||||
for table in (DBPlugin, DBClient, RoomState, UserProfile):
|
||||
table.bind(db)
|
||||
|
||||
if not db.has_table("alembic_version"):
|
||||
|
|
|
@ -13,11 +13,11 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from mautrix.client import ClientStore
|
||||
from mautrix.client import SyncStore
|
||||
from mautrix.types import SyncToken
|
||||
|
||||
|
||||
class ClientStoreProxy(ClientStore):
|
||||
class SyncStoreProxy(SyncStore):
|
||||
def __init__(self, db_instance) -> None:
|
||||
self.db_instance = db_instance
|
||||
|
||||
|
|
|
@ -19,8 +19,8 @@ import asyncio
|
|||
|
||||
from attr import dataclass
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from mautrix.client.api.types.util import (SerializableAttrs, SerializerError, serializer,
|
||||
deserializer)
|
||||
|
||||
from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer
|
||||
|
||||
from ..__meta__ import __version__
|
||||
from ..plugin_base import Plugin
|
||||
|
|
|
@ -22,7 +22,8 @@ import os
|
|||
|
||||
from ruamel.yaml import YAML, YAMLError
|
||||
from packaging.version import Version
|
||||
from mautrix.client.api.types.util import SerializerError
|
||||
|
||||
from mautrix.types import SerializerError
|
||||
|
||||
from ..lib.zipimport import zipimporter, ZipImportError
|
||||
from ..plugin_base import Plugin
|
||||
|
|
|
@ -13,13 +13,14 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Union, Awaitable, Optional, Tuple
|
||||
from typing import Union, Awaitable, Optional, Tuple, List
|
||||
from html import escape
|
||||
import asyncio
|
||||
|
||||
import attr
|
||||
|
||||
from mautrix.client import Client as MatrixClient, SyncStream
|
||||
from mautrix.util.formatter import parse_html
|
||||
from mautrix.util import markdown
|
||||
from mautrix.util import markdown, formatter
|
||||
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
|
||||
MessageType, TextMessageEventContent, Format, RelatesTo)
|
||||
|
||||
|
@ -32,7 +33,7 @@ def parse_formatted(message: str, allow_html: bool = False, render_markdown: boo
|
|||
html = message
|
||||
else:
|
||||
return message, escape(message)
|
||||
return parse_html(html), html
|
||||
return formatter.parse_html(html), html
|
||||
|
||||
|
||||
class MaubotMessageEvent(MessageEvent):
|
||||
|
@ -110,12 +111,12 @@ class MaubotMatrixClient(MatrixClient):
|
|||
content.set_edit(edits)
|
||||
return self.send_message(room_id, content, **kwargs)
|
||||
|
||||
async def dispatch_event(self, event: Event, source: SyncStream = SyncStream.INTERNAL) -> None:
|
||||
def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]:
|
||||
if isinstance(event, MessageEvent):
|
||||
event = MaubotMessageEvent(event, self)
|
||||
elif source != SyncStream.INTERNAL:
|
||||
event.client = self
|
||||
return await super().dispatch_event(event, source)
|
||||
return super().dispatch_event(event, source)
|
||||
|
||||
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
|
||||
event = await super().get_event(room_id, event_id)
|
||||
|
|
|
@ -36,7 +36,7 @@ from .config import Config
|
|||
from ..plugin_base import Plugin
|
||||
from ..loader import PluginMeta
|
||||
from ..matrix import MaubotMatrixClient
|
||||
from ..lib.store_proxy import ClientStoreProxy
|
||||
from ..lib.store_proxy import SyncStoreProxy
|
||||
from ..__meta__ import __version__
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -143,7 +143,7 @@ async def main():
|
|||
global client, bot
|
||||
|
||||
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
|
||||
client_session=http_client, loop=loop, store=ClientStoreProxy(nb),
|
||||
client_session=http_client, loop=loop, store=SyncStoreProxy(nb),
|
||||
log=logging.getLogger("maubot.client").getChild(user_id))
|
||||
|
||||
while True:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue