Stop using SQLAlchemy ORM and add colorful logs

This commit is contained in:
Tulir Asokan 2019-09-01 14:46:08 +03:00
parent 59998b99b1
commit b59eab2953
8 changed files with 90 additions and 65 deletions

View file

@ -13,11 +13,10 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, TYPE_CHECKING
import asyncio
import logging
from sqlalchemy.orm import Session
from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken, MatrixRequestError
@ -35,7 +34,6 @@ log = logging.getLogger("maubot.client")
class Client:
db: Session = None
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
@ -148,9 +146,7 @@ class Client:
def clear_cache(self) -> None:
self.stop_sync()
self.db_instance.filter_id = ""
self.db_instance.next_batch = ""
self.db.commit()
self.db_instance.edit(filter_id="", next_batch="")
self.start_sync()
def delete(self) -> None:
@ -158,8 +154,7 @@ class Client:
del self.cache[self.id]
except KeyError:
pass
self.db.delete(self.db_instance)
self.db.commit()
self.db_instance.delete()
def to_dict(self) -> dict:
return {
@ -183,14 +178,14 @@ class Client:
try:
return cls.cache[user_id]
except KeyError:
db_instance = db_instance or DBClient.query.get(user_id)
db_instance = db_instance or DBClient.get(user_id)
if not db_instance:
return None
return Client(db_instance)
@classmethod
def all(cls) -> List['Client']:
return [cls.get(user.id, user) for user in DBClient.query.all()]
def all(cls) -> Iterable['Client']:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_invite(self, evt: StrippedStateEvent) -> None:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
@ -314,8 +309,7 @@ class Client:
# endregion
def init(db: Session, loop: asyncio.AbstractEventLoop) -> List[Client]:
Client.db = db
def init(loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop
return Client.all()