Black everything

This commit is contained in:
Felipe Martin 2020-12-13 08:46:09 +01:00
parent 0b06098e6b
commit c8c5fefabc
Signed by: fmartingr
GPG Key ID: 716BC147715E716F
13 changed files with 72 additions and 30 deletions

View File

@ -3,7 +3,16 @@ import os.path
from functools import wraps
import structlog
from flask import Blueprint, render_template, request, session, redirect, url_for, flash, g
from flask import (
Blueprint,
render_template,
request,
session,
redirect,
url_for,
flash,
g,
)
from butterrobot.config import HOSTNAME
from butterrobot.db import UserQuery, ChannelQuery, ChannelPluginQuery
@ -19,14 +28,15 @@ def login_required(f):
@wraps(f)
def decorated_function(*args, **kwargs):
if g.user is None:
return redirect(url_for('admin.login_view', next=request.path))
return redirect(url_for("admin.login_view", next=request.path))
return f(*args, **kwargs)
return decorated_function
@admin.before_app_request
def load_logged_in_user():
user_id = session.get('user_id')
user_id = session.get("user_id")
if user_id is None:
g.user = None
@ -91,8 +101,7 @@ def channel_list_view():
def channel_detail_view(channel_id):
if request.method == "POST":
ChannelQuery.update(
channel_id,
enabled=request.form["enabled"] == "true",
channel_id, enabled=request.form["enabled"] == "true",
)
flash("Channel updated", "success")
@ -129,13 +138,13 @@ def channel_plugin_list_view():
def channel_plugin_detail_view(channel_plugin_id):
if request.method == "POST":
ChannelPluginQuery.update(
channel_plugin_id,
enabled=request.form["enabled"] == "true",
channel_plugin_id, enabled=request.form["enabled"] == "true",
)
flash("Plugin updated", category="success")
return redirect(request.headers.get("Referer"))
@admin.route("/channelplugins/<channel_plugin_id>/delete", methods=["POST"])
@login_required
def channel_plugin_delete_view(channel_plugin_id):

View File

@ -23,13 +23,14 @@ class ExternalProxyFix(object):
used by one of the reverse proxies in front of this in production.
It does nothing if the header is not present.
"""
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
host = environ.get('HTTP_X_EXTERNAL_HOST', '')
host = environ.get("HTTP_X_EXTERNAL_HOST", "")
if host:
environ['HTTP_HOST'] = host
environ["HTTP_HOST"] = host
return self.app(environ, start_response)
@ -47,10 +48,12 @@ def incoming_platform_message_view(platform, path=None):
if platform not in get_available_platforms():
return {"error": "Unknown platform"}, 400
q.put({"platform": platform, "request": {
"path": request.path,
"json": request.get_json()
}})
q.put(
{
"platform": platform,
"request": {"path": request.path, "json": request.get_json()},
}
)
return {}

View File

@ -3,7 +3,9 @@ import os
# --- Butter Robot -----------------------------------------------------------------
DEBUG = os.environ.get("DEBUG", "n") == "y"
HOSTNAME = os.environ.get("BUTTERROBOT_HOSTNAME", "butterrobot-dev.int.fmartingr.network")
HOSTNAME = os.environ.get(
"BUTTERROBOT_HOSTNAME", "butterrobot-dev.int.fmartingr.network"
)
LOG_LEVEL = os.environ.get("LOG_LEVEL", "ERROR")

View File

@ -10,7 +10,6 @@ from butterrobot.objects import Channel, ChannelPlugin, User
db = dataset.connect(DATABASE_PATH)
class Query:
class NotFound(Exception):
pass
@ -27,7 +26,7 @@ class Query:
yield cls.obj(**row)
@classmethod
def get(cls, **kwargs) -> 'class':
def get(cls, **kwargs) -> "class":
"""
Returns the object representation of an specific row in a table.
Allows retrieving object by multiple columns.
@ -61,7 +60,7 @@ class Query:
@classmethod
def update(cls, row_id, **fields):
fields.update({"id": row_id})
return db[cls.tablename].update(fields, ("id", ))
return db[cls.tablename].update(fields, ("id",))
@classmethod
def delete(cls, id):
@ -79,7 +78,7 @@ class UserQuery(Query):
).hex()
@classmethod
def check_credentials(cls, username, password) -> Union[User, 'False']:
def check_credentials(cls, username, password) -> Union[User, "False"]:
user = db[cls.tablename].find_one(username=username)
if user:
hash_password = cls._hash_password(password)
@ -125,7 +124,9 @@ class ChannelQuery(Query):
plugins = ChannelPluginQuery.get_from_channel_id(result["id"])
return cls.obj(plugins={plugin.plugin_id: plugin for plugin in plugins}, **result)
return cls.obj(
plugins={plugin.plugin_id: plugin for plugin in plugins}, **result
)
@classmethod
def delete(cls, _id):

View File

@ -51,8 +51,8 @@ class TelegramAPI:
"disable_notification": disable_notification,
"reply_to_message_id": reply_to_message_id,
}
response = requests.post(url, json=payload)
response_json = response.json()
if not response_json["ok"]:
raise cls.TelegramClientError(response_json)
raise cls.TelegramClientError(response_json)

View File

@ -14,7 +14,9 @@ structlog.configure(
structlog.processors.StackInfoRenderer(),
structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M.%S"),
structlog.processors.format_exc_info,
structlog.dev.ConsoleRenderer() if DEBUG else structlog.processors.JSONRenderer(),
structlog.dev.ConsoleRenderer()
if DEBUG
else structlog.processors.JSONRenderer(),
],
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),

View File

@ -36,6 +36,7 @@ class Channel:
@property
def channel_name(self):
from butterrobot.platforms import PLATFORMS
return PLATFORMS[self.platform].parse_channel_name_from_raw(self.channel_raw)

View File

@ -8,12 +8,16 @@ from butterrobot.platforms.debug import DebugPlatform
logger = structlog.get_logger(__name__)
PLATFORMS = {platform.ID: platform for platform in (SlackPlatform, TelegramPlatform, DebugPlatform)}
PLATFORMS = {
platform.ID: platform
for platform in (SlackPlatform, TelegramPlatform, DebugPlatform)
}
@lru_cache
def get_available_platforms():
from butterrobot.platforms import PLATFORMS
available_platforms = {}
for platform in PLATFORMS.values():
logger.debug("Setting up", platform=platform.ID)

View File

@ -17,6 +17,7 @@ class Platform:
"""
Used when the platform needs to make a response right away instead of async.
"""
data: dict
status_code: int = 200

View File

@ -35,6 +35,10 @@ class DebugPlatform(Platform):
from_bot=bool(request_data.get("from_bot", False)),
author=request_data.get("author", "Debug author"),
chat=request_data.get("chat", "Debug chat ID"),
channel=Channel(platform=cls.ID, platform_channel_id=request_data.get("chat"), channel_raw={}),
channel=Channel(
platform=cls.ID,
platform_channel_id=request_data.get("chat"),
channel_raw={},
),
raw={},
)

View File

@ -80,6 +80,8 @@ class TelegramPlatform(Platform):
from_bot=request["json"]["message"]["from"]["is_bot"],
author=request["json"]["message"]["from"]["id"],
chat=str(request["json"]["message"]["chat"]["id"]),
channel=cls.parse_channel_from_message(request["json"]["message"]["chat"]),
channel=cls.parse_channel_from_message(
request["json"]["message"]["chat"]
),
raw=request["json"],
)

View File

@ -15,7 +15,9 @@ q = queue.Queue()
def handle_message(platform: str, request: dict):
try:
message = get_available_platforms()[platform].parse_incoming_message(request=request)
message = get_available_platforms()[platform].parse_incoming_message(
request=request
)
except Platform.PlatformAuthResponse as response:
return response.data, response.status_code
except Exception as error:
@ -36,7 +38,9 @@ def handle_message(platform: str, request: dict):
channel = ChannelQuery.get_by_platform(platform, message.chat)
except ChannelQuery.NotFound:
# If channel is still not present on the database, create it (defaults to disabled)
channel = ChannelQuery.create(platform, message.chat, channel_raw=message.channel.channel_raw)
channel = ChannelQuery.create(
platform, message.chat, channel_raw=message.channel.channel_raw
)
if not channel.enabled:
return
@ -45,7 +49,9 @@ def handle_message(platform: str, request: dict):
if not channel.has_enabled_plugin(plugin_id):
continue
for response_message in get_available_plugins()[plugin_id].on_message(message, plugin_config=channel_plugin.config):
for response_message in get_available_plugins()[plugin_id].on_message(
message, plugin_config=channel_plugin.config
):
get_available_platforms()[platform].methods.send_message(response_message)
@ -55,5 +61,6 @@ def worker_thread():
handle_message(item["platform"], item["request"])
q.task_done()
# turn-on the worker thread
worker = threading.Thread(target=worker_thread, daemon=True).start()

View File

@ -17,7 +17,9 @@ class LoquitoPlugin(Plugin):
@classmethod
def on_message(cls, message, **kwargs):
if "lo quito" in message.text.lower():
yield Message(chat=message.chat, reply_to=message.id, text="Loquito tu.",)
yield Message(
chat=message.chat, reply_to=message.id, text="Loquito tu.",
)
class DicePlugin(Plugin):
@ -42,4 +44,8 @@ class CoinPlugin(Plugin):
@classmethod
def on_message(cls, message: Message, **kwargs):
if message.text.startswith("!coin"):
yield Message(chat=message.chat, reply_to=message.id, text=random.choice(("heads", "tails")))
yield Message(
chat=message.chat,
reply_to=message.id,
text=random.choice(("heads", "tails")),
)