mirror of
https://github.com/maubot/maubot
synced 2025-09-02 00:00:39 +00:00
Stop using SQLAlchemy ORM and add colorful logs
This commit is contained in:
parent
59998b99b1
commit
b59eab2953
8 changed files with 90 additions and 65 deletions
|
@ -13,11 +13,10 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
||||
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
|
||||
|
@ -35,7 +34,6 @@ log = logging.getLogger("maubot.client")
|
|||
|
||||
|
||||
class Client:
|
||||
db: Session = None
|
||||
log: logging.Logger = None
|
||||
loop: asyncio.AbstractEventLoop = None
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
|
@ -148,9 +146,7 @@ class Client:
|
|||
|
||||
def clear_cache(self) -> None:
|
||||
self.stop_sync()
|
||||
self.db_instance.filter_id = ""
|
||||
self.db_instance.next_batch = ""
|
||||
self.db.commit()
|
||||
self.db_instance.edit(filter_id="", next_batch="")
|
||||
self.start_sync()
|
||||
|
||||
def delete(self) -> None:
|
||||
|
@ -158,8 +154,7 @@ class Client:
|
|||
del self.cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.db.delete(self.db_instance)
|
||||
self.db.commit()
|
||||
self.db_instance.delete()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
|
@ -183,14 +178,14 @@ class Client:
|
|||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBClient.query.get(user_id)
|
||||
db_instance = db_instance or DBClient.get(user_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return Client(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> List['Client']:
|
||||
return [cls.get(user.id, user) for user in DBClient.query.all()]
|
||||
def all(cls) -> Iterable['Client']:
|
||||
return (cls.get(user.id, user) for user in DBClient.all())
|
||||
|
||||
async def _handle_invite(self, evt: StrippedStateEvent) -> None:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
|
@ -314,8 +309,7 @@ class Client:
|
|||
# endregion
|
||||
|
||||
|
||||
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
|
||||
Client.db = db
|
||||
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
return Client.all()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue