Try another approach for plugin web apps

This commit is contained in:
Tulir Asokan 2019-03-07 19:57:10 +02:00
parent 3c2d0a9fde
commit b3e1f1d4bc
3 changed files with 79 additions and 20 deletions

View file

@ -13,11 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Dict
from typing import Tuple, List, Dict, Callable, Awaitable
from functools import partial
import logging
import asyncio
from aiohttp import web
from aiohttp import web, hdrs, URL
from aiohttp.abc import AbstractAccessLogger
import pkg_resources
@ -34,6 +35,62 @@ class AccessLogger(AbstractAccessLogger):
f'in {round(time, 4)}s"')
Handler = Callable[[web.Request], Awaitable[web.Response]]
Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
class PluginWebApp(web.UrlDispatcher):
def __init__(self):
super().__init__()
self._middleware: List[Middleware] = []
def add_middleware(self, middleware: Middleware) -> None:
self._middleware.append(middleware)
def remove_middleware(self, middleware: Middleware) -> None:
self._middleware.remove(middleware)
async def handle(self, request: web.Request) -> web.Response:
match_info = await self.resolve(request)
match_info.freeze()
resp = None
request._match_info = match_info
expect = request.headers.get(hdrs.EXPECT)
if expect:
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is None:
handler = match_info.handler
for middleware in self._middleware:
handler = partial(middleware, handler=handler)
resp = await handler(request)
return resp
class PrefixResource(web.Resource):
def __init__(self, prefix, *, name=None):
assert not prefix or prefix.startswith('/'), prefix
assert prefix in ('', '/') or not prefix.endswith('/'), prefix
super().__init__(name=name)
self._prefix = URL.build(path=prefix).raw_path
@property
def canonical(self):
return self._prefix
def add_prefix(self, prefix):
assert prefix.startswith('/')
assert not prefix.endswith('/')
assert len(prefix) > 1
self._prefix = prefix + self._prefix
def _match(self, path: str) -> dict:
return {} if self.raw_match(path) else None
def raw_match(self, path: str) -> bool:
return path and path.startswith(self._prefix)
class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server")
@ -45,38 +102,38 @@ class MaubotServer:
as_path = PathBuilder(config["server.appservice_base_path"])
self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
self.plugin_apps: Dict[str, web.Application] = {}
self.app.router.add_view(config["server.plugin_base_path"], self.handle_plugin_path)
self.plugin_routes: Dict[str, PluginWebApp] = {}
resource = PrefixResource(config["server.plugin_base_path"])
resource.add_route(hdrs.METH_ANY, self.handle_plugin_path)
self.app.router.register_resource(resource)
self.setup_management_ui()
self.runner = web.AppRunner(self.app, access_log_class=AccessLogger)
async def handle_plugin_path(self, request: web.Request) -> web.Response:
for path, app in self.plugin_apps.items():
for path, app in self.plugin_routes.items():
if request.path.startswith(path):
# TODO there's probably a correct way to do these
request._rel_url.path = request._rel_url.path[len(path):]
return await app._handle(request)
request = request.clone(rel_url=request.path[len(path):])
return await app.handle(request)
return web.Response(status=404)
def get_instance_subapp(self, instance_id: str) -> Tuple[web.Application, str]:
subpath = self.config["server.plugin_base_path"].format(id=instance_id)
def get_instance_subapp(self, instance_id: str) -> Tuple[PluginWebApp, str]:
subpath = self.config["server.plugin_base_path"] + instance_id
url = self.config["server.public_url"] + subpath
try:
return self.plugin_apps[subpath], url
return self.plugin_routes[subpath], url
except KeyError:
app = web.Application(loop=self.loop)
self.plugin_apps[subpath] = app
app = PluginWebApp()
self.plugin_routes[subpath] = app
return app, url
def remove_instance_webapp(self, instance_id: str) -> None:
try:
subapp: web.Application = self.plugin_apps.pop(instance_id)
subpath = self.config["server.plugin_base_path"] + instance_id
self.plugin_routes.pop(subpath)
except KeyError:
return
subapp.shutdown()
subapp.cleanup()
def setup_management_ui(self) -> None:
ui_base = self.config["server.ui_base_path"]