mirror of
https://github.com/maubot/maubot
synced 2025-08-29 15:40:37 +00:00
Stop using SQLAlchemy ORM and add colorful logs
This commit is contained in:
parent
59998b99b1
commit
b59eab2953
8 changed files with 90 additions and 65 deletions
42
maubot/db.py
42
maubot/db.py
|
@ -13,22 +13,19 @@
|
|||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import cast
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
|
||||
from sqlalchemy.orm import Query, Session, sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.engine.base import Engine
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
|
||||
from mautrix.bridge.db import Base
|
||||
|
||||
from .config import Config
|
||||
|
||||
Base: declarative_base = declarative_base()
|
||||
|
||||
|
||||
class DBPlugin(Base):
|
||||
query: Query
|
||||
__tablename__ = "plugin"
|
||||
|
||||
id: str = Column(String(255), primary_key=True)
|
||||
|
@ -39,9 +36,16 @@ class DBPlugin(Base):
|
|||
nullable=False)
|
||||
config: str = Column(Text, nullable=False, default='')
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['DBPlugin']:
|
||||
return cls._select_all()
|
||||
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional['DBPlugin']:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
class DBClient(Base):
|
||||
query: Query
|
||||
__tablename__ = "client"
|
||||
|
||||
id: UserID = Column(String(255), primary_key=True)
|
||||
|
@ -58,15 +62,23 @@ class DBClient(Base):
|
|||
displayname: str = Column(String(255), nullable=False, default="")
|
||||
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> Iterable['DBClient']:
|
||||
return cls._select_all()
|
||||
|
||||
def init(config: Config) -> Session:
|
||||
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
|
||||
db_factory = sessionmaker(bind=db_engine)
|
||||
db_session = scoped_session(db_factory)
|
||||
@classmethod
|
||||
def get(cls, id: str) -> Optional['DBClient']:
|
||||
return cls._select_one_or_none(cls.c.id == id)
|
||||
|
||||
|
||||
def init(config: Config) -> Engine:
|
||||
db_engine = sql.create_engine(config["database"])
|
||||
Base.metadata.bind = db_engine
|
||||
Base.metadata.create_all()
|
||||
|
||||
DBPlugin.query = db_session.query_property()
|
||||
DBClient.query = db_session.query_property()
|
||||
for table in (DBPlugin, DBClient):
|
||||
table.db = db_engine
|
||||
table.t = table.__table__
|
||||
table.c = table.t.c
|
||||
table.column_names = table.c.keys()
|
||||
|
||||
return cast(Session, db_session)
|
||||
return db_engine
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue