mirror of
https://github.com/maubot/maubot
synced 2025-09-02 00:00:39 +00:00
Add support for end-to-end encryption. Fixes #46
This commit is contained in:
parent
4e767a10e4
commit
69d7a4341b
17 changed files with 203 additions and 24 deletions
|
@ -13,24 +13,46 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
|
||||
from os import path
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from yarl import URL
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
|
||||
PresenceState, StateFilter)
|
||||
from mautrix.client import InternalEventType
|
||||
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
|
||||
|
||||
from .lib.store_proxy import ClientStoreProxy
|
||||
from .lib.store_proxy import SyncStoreProxy
|
||||
from .db import DBClient
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
try:
|
||||
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
|
||||
PickleCryptoStore)
|
||||
|
||||
|
||||
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
|
||||
pass
|
||||
except ImportError:
|
||||
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
|
||||
SQLStateStore = BaseSQLStateStore
|
||||
|
||||
try:
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
from mautrix.crypto import PgCryptoStore
|
||||
except ImportError:
|
||||
AsyncDatabase = None
|
||||
PgCryptoStore = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .instance import PluginInstance
|
||||
from .config import Config
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
@ -40,10 +62,15 @@ class Client:
|
|||
loop: asyncio.AbstractEventLoop = None
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
http_client: ClientSession = None
|
||||
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
|
||||
crypto_pickle_dir: str = None
|
||||
crypto_db: 'AsyncDatabase' = None
|
||||
|
||||
references: Set['PluginInstance']
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
crypto: Optional['OlmMachine']
|
||||
crypto_store: Optional['CryptoStore']
|
||||
started: bool
|
||||
|
||||
remote_displayname: Optional[str]
|
||||
|
@ -61,7 +88,15 @@ class Client:
|
|||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=self.log, loop=self.loop, device_id=self.device_id,
|
||||
store=ClientStoreProxy(self.db_instance))
|
||||
sync_store=SyncStoreProxy(self.db_instance),
|
||||
state_store=self.global_state_store)
|
||||
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
|
||||
self.crypto_store = self._make_crypto_store()
|
||||
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
|
||||
self.client.crypto = self.crypto
|
||||
else:
|
||||
self.crypto_store = None
|
||||
self.crypto = None
|
||||
self.client.ignore_initial_sync = True
|
||||
self.client.ignore_first_sync = True
|
||||
self.client.presence = PresenceState.ONLINE if self.online else PresenceState.OFFLINE
|
||||
|
@ -71,6 +106,14 @@ class Client:
|
|||
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
|
||||
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
|
||||
|
||||
def _make_crypto_store(self) -> 'CryptoStore':
|
||||
if self.crypto_db:
|
||||
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
|
||||
elif self.crypto_pickle_dir:
|
||||
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
|
||||
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
|
||||
raise ValueError("Crypto database not configured")
|
||||
|
||||
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
|
||||
async def handler(data: Dict[str, Any]) -> None:
|
||||
self.sync_ok = ok
|
||||
|
@ -130,6 +173,16 @@ class Client:
|
|||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
if self.crypto:
|
||||
self.log.debug("Enabling end-to-end encryption support")
|
||||
await self.crypto_store.open()
|
||||
crypto_device_id = await self.crypto_store.get_device_id()
|
||||
if crypto_device_id and crypto_device_id != self.device_id:
|
||||
self.log.warning("Mismatching device ID in crypto store and main database. "
|
||||
"Encryption may not work.")
|
||||
await self.crypto.load()
|
||||
if not crypto_device_id:
|
||||
await self.crypto_store.put_device_id(self.device_id)
|
||||
self.start_sync()
|
||||
await self._update_remote_profile()
|
||||
self.started = True
|
||||
|
@ -154,6 +207,8 @@ class Client:
|
|||
self.started = False
|
||||
await self.stop_plugins()
|
||||
self.stop_sync()
|
||||
if self.crypto:
|
||||
await self.crypto_store.close()
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
|
@ -172,6 +227,7 @@ class Client:
|
|||
"id": self.id,
|
||||
"homeserver": self.homeserver,
|
||||
"access_token": self.access_token,
|
||||
"device_id": self.device_id,
|
||||
"enabled": self.enabled,
|
||||
"started": self.started,
|
||||
"sync": self.sync,
|
||||
|
@ -243,11 +299,12 @@ class Client:
|
|||
return
|
||||
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
|
||||
token=access_token or self.access_token, loop=self.loop,
|
||||
client_session=self.http_client, log=self.log)
|
||||
client_session=self.http_client, device_id=self.device_id,
|
||||
log=self.log, state_store=self.global_state_store)
|
||||
mxid = await new_client.whoami()
|
||||
if mxid != self.id:
|
||||
raise ValueError(f"MXID mismatch: {mxid}")
|
||||
new_client.store = self.db_instance
|
||||
new_client.sync_store = self.db_instance
|
||||
self.stop_sync()
|
||||
self.client = new_client
|
||||
self.db_instance.homeserver = homeserver
|
||||
|
@ -341,7 +398,30 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
|
||||
if OlmMachine:
|
||||
db_type = config["crypto_database.type"]
|
||||
if db_type == "default":
|
||||
db_url = config["database"]
|
||||
parsed_url = URL(db_url)
|
||||
if parsed_url.scheme == "sqlite":
|
||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||
elif parsed_url.scheme == "postgres":
|
||||
if not PgCryptoStore:
|
||||
log.warning("Default database is postgres, but asyncpg is not installed. "
|
||||
"Encryption will not work.")
|
||||
else:
|
||||
Client.crypto_db = AsyncDatabase(url=db_url,
|
||||
upgrade_table=PgCryptoStore.upgrade_table)
|
||||
elif db_type == "pickle":
|
||||
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
|
||||
elif db_type == "postgres" and PgCryptoStore:
|
||||
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
|
||||
upgrade_table=PgCryptoStore.upgrade_table)
|
||||
else:
|
||||
raise ValueError("Unsupported crypto database type")
|
||||
|
||||
return Client.all()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue