mirror of
https://github.com/maubot/maubot
synced 2025-09-02 00:00:39 +00:00
Refactor things and implement instance API
This commit is contained in:
parent
cbeff0c0cb
commit
bc87b2a02b
14 changed files with 249 additions and 100 deletions
|
@ -14,10 +14,12 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Optional, Set, TYPE_CHECKING
|
||||
from aiohttp import ClientSession
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
EventType, Filter, RoomFilter, RoomEventFilter)
|
||||
|
@ -32,6 +34,7 @@ log = logging.getLogger("maubot.client")
|
|||
|
||||
|
||||
class Client:
|
||||
db: Session = None
|
||||
log: logging.Logger = None
|
||||
loop: asyncio.AbstractEventLoop = None
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
|
@ -73,12 +76,12 @@ class Client:
|
|||
user_id = await self.client.whoami()
|
||||
except MatrixInvalidToken as e:
|
||||
self.log.error(f"Invalid token: {e}. Disabling client")
|
||||
self.enabled = False
|
||||
self.db_instance.enabled = False
|
||||
return
|
||||
except MatrixRequestError:
|
||||
if try_n >= 5:
|
||||
self.log.exception("Failed to get /account/whoami, disabling client")
|
||||
self.enabled = False
|
||||
self.db_instance.enabled = False
|
||||
else:
|
||||
self.log.exception(f"Failed to get /account/whoami, "
|
||||
f"retrying in {(try_n + 1) * 10}s")
|
||||
|
@ -86,7 +89,7 @@ class Client:
|
|||
return
|
||||
if user_id != self.id:
|
||||
self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}")
|
||||
self.enabled = False
|
||||
self.db_instance.enabled = False
|
||||
return
|
||||
if not self.filter_id:
|
||||
self.filter_id = await self.client.create_filter(Filter(
|
||||
|
@ -100,8 +103,7 @@ class Client:
|
|||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
if self.sync:
|
||||
self.client.start(self.filter_id)
|
||||
self.start_sync()
|
||||
self.started = True
|
||||
self.log.info("Client started, starting plugin instances...")
|
||||
await self.start_plugins()
|
||||
|
@ -110,12 +112,19 @@ class Client:
|
|||
await asyncio.gather(*[plugin.start() for plugin in self.references], loop=self.loop)
|
||||
|
||||
async def stop_plugins(self) -> None:
|
||||
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.running],
|
||||
await asyncio.gather(*[plugin.stop() for plugin in self.references if plugin.started],
|
||||
loop=self.loop)
|
||||
|
||||
def start_sync(self) -> None:
|
||||
if self.sync:
|
||||
self.client.start(self.filter_id)
|
||||
|
||||
def stop_sync(self) -> None:
|
||||
self.client.stop()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.started = False
|
||||
self.client.stop()
|
||||
self.stop_sync()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
@ -233,7 +242,8 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||
Client.db = db
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
return Client.all()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue