Refactor __main__.py and fix things

This commit is contained in:
Tulir Asokan 2021-11-19 15:22:54 +02:00
parent c685eb5e08
commit 7c9668d8bc
9 changed files with 84 additions and 117 deletions

View file

@ -14,17 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from os import path
import asyncio
import logging
from aiohttp import ClientSession
from yarl import URL
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter)
PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
@ -33,13 +31,12 @@ from .db import DBClient
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore,
PickleCryptoStore)
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, CryptoStore
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
except ImportError:
except ImportError as e:
OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None
SQLStateStore = BaseSQLStateStore
@ -63,8 +60,7 @@ class Client:
cache: Dict[UserID, 'Client'] = {}
http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_pickle_dir: str = None
crypto_db: 'AsyncDatabase' = None
crypto_db: Optional['AsyncDatabase'] = None
references: Set['PluginInstance']
db_instance: DBClient
@ -90,7 +86,7 @@ class Client:
log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir):
if OlmMachine and self.device_id and self.crypto_db:
self.crypto_store = self._make_crypto_store()
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
@ -109,9 +105,6 @@ class Client:
def _make_crypto_store(self) -> 'CryptoStore':
if self.crypto_db:
return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db)
elif self.crypto_pickle_dir:
return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto",
path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle"))
raise ValueError("Crypto database not configured")
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
@ -330,7 +323,7 @@ class Client:
return self.db_instance.access_token
@property
def device_id(self) -> str:
def device_id(self) -> DeviceID:
return self.db_instance.device_id
@property
@ -403,25 +396,9 @@ def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.loop = loop
if OlmMachine:
db_type = config["crypto_database.type"]
if db_type == "default":
db_url = config["crypto_database"]
if db_url == "default":
db_url = config["database"]
parsed_url = URL(db_url)
if parsed_url.scheme == "sqlite":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif parsed_url.scheme == "postgres" or parsed_url.scheme == "postgresql":
if not PgCryptoStore:
log.warning("Default database is postgres, but asyncpg is not installed. "
"Encryption will not work.")
else:
Client.crypto_db = AsyncDatabase(url=db_url,
upgrade_table=PgCryptoStore.upgrade_table)
elif db_type == "pickle":
Client.crypto_pickle_dir = config["crypto_database.pickle_dir"]
elif (db_type == "postgres" or db_type == "postgresql") and PgCryptoStore:
Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"],
upgrade_table=PgCryptoStore.upgrade_table)
else:
raise ValueError("Unsupported crypto database type")
Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
return Client.all()