Stop using SQLAlchemy ORM and add colorful logs

This commit is contained in:
Tulir Asokan 2019-09-01 14:46:08 +03:00
parent 59998b99b1
commit b59eab2953
8 changed files with 90 additions and 65 deletions

View file

@ -13,22 +13,19 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import cast
from typing import Iterable, Optional
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
from mautrix.bridge.db import Base
from .config import Config
Base: declarative_base = declarative_base()
class DBPlugin(Base):
query: Query
__tablename__ = "plugin"
id: str = Column(String(255), primary_key=True)
@ -39,9 +36,16 @@ class DBPlugin(Base):
nullable=False)
config: str = Column(Text, nullable=False, default='')
@classmethod
def all(cls) -> Iterable['DBPlugin']:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional['DBPlugin']:
return cls._select_one_or_none(cls.c.id == id)
class DBClient(Base):
query: Query
__tablename__ = "client"
id: UserID = Column(String(255), primary_key=True)
@ -58,15 +62,23 @@ class DBClient(Base):
displayname: str = Column(String(255), nullable=False, default="")
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod
def all(cls) -> Iterable['DBClient']:
return cls._select_all()
def init(config: Config) -> Session:
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
db_factory = sessionmaker(bind=db_engine)
db_session = scoped_session(db_factory)
@classmethod
def get(cls, id: str) -> Optional['DBClient']:
return cls._select_one_or_none(cls.c.id == id)
def init(config: Config) -> Engine:
db_engine = sql.create_engine(config["database"])
Base.metadata.bind = db_engine
Base.metadata.create_all()
DBPlugin.query = db_session.query_property()
DBClient.query = db_session.query_property()
for table in (DBPlugin, DBClient):
table.db = db_engine
table.t = table.__table__
table.c = table.t.c
table.column_names = table.c.keys()
return cast(Session, db_session)
return db_engine