164 lines
4.4 KiB
Python
164 lines
4.4 KiB
Python
import hashlib
import hmac
from typing import Union

import dataset

from butterrobot.config import SECRET_KEY, DATABASE_PATH
from butterrobot.objects import User, Channel, ChannelPlugin
|
|
|
|
# Module-level database handle opened from DATABASE_PATH; the query classes
# in this module reach their tables through it as `db[cls.tablename]`.
db = dataset.connect(DATABASE_PATH)
|
|
|
|
|
|
class Query:
    """
    Base class for table-backed queries.

    Subclasses provide `tablename` (the table to operate on) and `obj`
    (the callable that turns a row's columns into a result object).
    """

    class NotFound(Exception):
        """Raised by `get` when no row matches the requested columns."""

    class Duplicated(Exception):
        """Available to subclasses to signal that a row already exists."""

    @classmethod
    def all(cls):
        """
        Yield one `cls.obj` instance for every row stored in the table.
        """
        table = db[cls.tablename]
        yield from (cls.obj(**record) for record in table.all())

    @classmethod
    def get(cls, **kwargs):
        """
        Fetch a single row and return it wrapped in `cls.obj`.

        The keyword arguments are the column values to match; several
        columns can be combined in one call.
        Raises `NotFound` when the lookup returns nothing.
        """
        match = db[cls.tablename].find_one(**kwargs)
        if match:
            return cls.obj(**match)
        raise cls.NotFound

    @classmethod
    def create(cls, **kwargs):
        """
        Insert the provided column values as a new row.

        Returns the row_id
        TODO: Return obj?
        """
        table = db[cls.tablename]
        return table.insert(kwargs)

    @classmethod
    def exists(cls, **kwargs) -> bool:
        """
        Tell whether `get` finds a row for the provided columns.
        """
        try:
            cls.get(**kwargs)
        except cls.NotFound:
            return False
        else:
            return True

    @classmethod
    def update(cls, row_id, **fields):
        """
        Write `fields` onto the row whose `id` column equals `row_id`.

        Returns the result of the table's `update` call.
        """
        payload = {**fields, "id": row_id}
        return db[cls.tablename].update(payload, ("id",))

    @classmethod
    def delete(cls, id):
        """
        Delete the row whose `id` column equals `id`.

        Returns the result of the table's `delete` call.
        """
        table = db[cls.tablename]
        return table.delete(id=id)
|
|
|
|
|
|
class UserQuery(Query):
    """
    Queries against the `users` table, with password hashing on creation
    and credential verification.
    """

    tablename = "users"
    obj = User

    @classmethod
    def _hash_password(cls, password):
        """
        Return the hex-encoded PBKDF2-HMAC-SHA256 digest of `password`.

        NOTE(review): SECRET_KEY is used as the salt, so every user shares
        the same salt. Switching to a per-user salt would invalidate the
        hashes already stored, so it is flagged here rather than changed.
        """
        return hashlib.pbkdf2_hmac(
            "sha256", password.encode("utf-8"), str.encode(SECRET_KEY), 100000
        ).hex()

    @classmethod
    def check_credentials(cls, username, password) -> Union[User, bool]:
        """
        Verify a username/password pair.

        Returns the matching `User` when the hashed password equals the
        stored one, and `False` otherwise (unknown user or wrong password).
        """
        user = db[cls.tablename].find_one(username=username)
        if not user:
            return False

        stored = user["password"]
        candidate = cls._hash_password(password)
        # Compare in constant time so response timing does not reveal how
        # much of the hash matched. A non-string stored value simply fails
        # the check, as it did with the previous `==` comparison.
        if isinstance(stored, str) and hmac.compare_digest(
            stored.encode("utf-8"), candidate.encode("utf-8")
        ):
            return cls.obj(**user)
        return False

    @classmethod
    def create(cls, **kwargs):
        """
        Create a user row, storing the hash of `password` instead of the
        plain text.

        Returns the value of the base `create` (the row_id).
        Raises `KeyError` when no `password` keyword is given.
        """
        kwargs["password"] = cls._hash_password(kwargs["password"])
        return super().create(**kwargs)
|
|
|
|
|
|
class ChannelQuery(Query):
    """
    Queries against the `channels` table. Lookups also attach the
    channel's plugins, keyed by `plugin_id`.
    """

    tablename = "channels"
    obj = Channel

    @classmethod
    def create(cls, platform, platform_channel_id, enabled=False, channel_raw=None):
        """
        Insert a new channel row and return it as a `Channel`.

        `channel_raw` defaults to a fresh empty dict on every call, so the
        default is never shared between the objects returned.
        The returned object carries the id of the inserted row.
        """
        params = {
            "platform": platform,
            "platform_channel_id": platform_channel_id,
            "enabled": enabled,
            "channel_raw": {} if channel_raw is None else channel_raw,
        }
        obj_id = super().create(**params)
        return cls.obj(id=obj_id, **params)

    @classmethod
    def get(cls, _id=None, **kwargs):
        """
        Return the channel matching `_id` and/or the given column values,
        with its plugins attached.

        Accepting column keywords keeps this override compatible with the
        base `get(**kwargs)`, which the inherited `exists` relies on.
        Raises `NotFound` when no row matches.
        """
        if _id is not None:
            kwargs["id"] = _id

        row = db[cls.tablename].find_one(**kwargs)
        if not row:
            raise cls.NotFound

        channel = cls.obj(**row)
        plugins = ChannelPluginQuery.get_from_channel_id(row["id"])
        channel.plugins = {plugin.plugin_id: plugin for plugin in plugins}
        return channel

    @classmethod
    def get_by_platform(cls, platform, platform_channel_id):
        """
        Return the channel identified by `platform` and
        `platform_channel_id`, with its plugins attached.

        Raises `NotFound` when no row matches.
        """
        result = db[cls.tablename].find_one(
            platform=platform, platform_channel_id=platform_channel_id
        )
        if not result:
            raise cls.NotFound

        plugins = ChannelPluginQuery.get_from_channel_id(result["id"])

        return cls.obj(
            plugins={plugin.plugin_id: plugin for plugin in plugins}, **result
        )

    @classmethod
    def delete(cls, _id):
        """
        Delete the channel's plugin rows, then the channel row itself.

        Returns the result of the base `delete`.
        """
        # Plugin rows go first so none are left pointing at a removed channel.
        ChannelPluginQuery.delete_by_channel(channel_id=_id)
        return super().delete(_id)
|
|
|
|
|
|
class ChannelPluginQuery(Query):
    """
    Queries against the `channel_plugin` table, which links plugins to
    channels.
    """

    tablename = "channel_plugin"
    obj = ChannelPlugin

    @classmethod
    def create(cls, channel_id, plugin_id, enabled=False, config=None):
        """
        Insert a channel/plugin row and return it as a `ChannelPlugin`.

        `config` defaults to a fresh empty dict on every call, so the
        default is never shared between the objects returned.
        Raises `Duplicated` when a row for this `channel_id` and
        `plugin_id` already exists.
        """
        if cls.exists(channel_id=channel_id, plugin_id=plugin_id):
            raise cls.Duplicated

        params = {
            "channel_id": channel_id,
            "plugin_id": plugin_id,
            "enabled": enabled,
            "config": {} if config is None else config,
        }
        obj_id = super().create(**params)
        return cls.obj(id=obj_id, **params)

    @classmethod
    def get_from_channel_id(cls, channel_id):
        """
        Yield a `ChannelPlugin` for every row whose `channel_id` matches.
        """
        # All rows are collected into a list before anything is yielded, so
        # the query is fully consumed by the time a caller such as
        # `delete_by_channel` starts deleting rows from the same table.
        yield from [
            cls.obj(**row) for row in db[cls.tablename].find(channel_id=channel_id)
        ]

    @classmethod
    def delete_by_channel(cls, channel_id):
        """
        Delete every plugin row attached to the given channel.
        """
        for item in cls.get_from_channel_id(channel_id):
            cls.delete(item.id)
|