Added database tests along with workflows

This commit is contained in:
Felipe Martin 2020-12-11 15:18:20 +01:00
parent b0e82fdefc
commit 3dcad9badf
Signed by: fmartingr
GPG Key ID: 716BC147715E716F
6 changed files with 747 additions and 380 deletions

27
.github/workflows/black.yaml vendored Normal file
View File

@ -0,0 +1,27 @@
name: Black
on:
push:
branches: [ master, stable ]
pull_request:
branches: [ master, stable ]
jobs:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install dependencies
run: |
pip install --upgrade pip
pip install black
- name: Black check
run: |
black --check butterrobot

32
.github/workflows/pytest.yaml vendored Normal file
View File

@ -0,0 +1,32 @@
name: Pytest
on:
push:
branches: [ master, stable ]
pull_request:
branches: [ master, stable ]
jobs:
pytest:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install --upgrade pip poetry
poetry install
- name: Test with pytest
run: |
ls
poetry run pytest --cov=butterrobot

View File

@ -1,5 +1,6 @@
import hashlib
import os
from typing import Union
import dataset
@ -19,14 +20,40 @@ class Query:
@classmethod
def all(cls):
for row in cls._table.all():
yield cls._obj(**row)
"""
Iterate over all rows on a table.
"""
for row in db[cls.tablename].all():
yield cls.obj(**row)
@classmethod
def exists(cls, *args, **kwargs):
def get(cls, **kwargs) -> 'class':
"""
Returns the object representation of an specific row in a table.
Allows retrieving object by multiple columns.
Raises `NotFound` error if query return no results.
"""
row = db[cls.tablename].find_one()
if not row:
raise cls.NotFound
return cls.obj(**row)
@classmethod
def create(cls, **kwargs):
"""
Creates a new row in the table with the provided arguments.
Returns the row_id
TODO: Return obj?
"""
return db[cls.tablename].insert(kwargs)
@classmethod
def exists(cls, **kwargs) -> bool:
"""
Check for the existence of a row with the provided columns.
"""
try:
# Using only *args since those are supposed to be mandatory
cls.get(*args)
cls.get(**kwargs)
except cls.NotFound:
return False
return True
@ -34,27 +61,16 @@ class Query:
@classmethod
def update(cls, row_id, **fields):
fields.update({"id": row_id})
return cls._table.update(fields, ("id", ))
return db[cls.tablename].update(fields, ("id", ))
@classmethod
def get(cls, _id):
row = cls._table.find_one(id=_id)
if not row:
raise cls.NotFound
return cls._obj(**row)
def delete(cls, id):
return db[cls.tablename].delete(id=id)
@classmethod
def update(cls, _id, **fields):
fields.update({"id": _id})
return cls._table.update(fields, ("id"))
@classmethod
def delete(cls, _id):
cls._table.delete(id=_id)
class UserQuery(Query):
_table = db["users"]
_obj = User
tablename = "users"
obj = User
@classmethod
def _hash_password(cls, password):
@ -63,32 +79,23 @@ class UserQuery(Query):
).hex()
@classmethod
def check_credentials(cls, username, password):
user = cls._table.find_one(username=username)
def check_credentials(cls, username, password) -> Union[User, 'False']:
user = db[cls.tablename].find_one(username=username)
if user:
hash_password = cls._hash_password(password)
if user["password"] == hash_password:
return cls._obj(**user)
return cls.obj(**user)
return False
@classmethod
def create(cls, username, password):
hash_password = cls._hash_password(password)
cls._table.insert({"username": username, "password": hash_password})
@classmethod
def delete(cls, username):
return cls._table.delete(username=username)
@classmethod
def update(cls, username, **fields):
fields.update({"username": username})
return cls._table.update(fields, ("username",))
def create(cls, **kwargs):
kwargs["password"] = cls._hash_password(kwargs["password"])
super().create(**kwargs)
class ChannelQuery(Query):
_table = db["channels"]
_obj = Channel
tablename = "channels"
obj = Channel
@classmethod
def create(cls, platform, platform_channel_id, enabled=False, channel_raw={}):
@ -98,8 +105,8 @@ class ChannelQuery(Query):
"enabled": enabled,
"channel_raw": channel_raw,
}
cls._table.insert(params)
return cls._obj(**params)
super().create(**params)
return cls.obj(**params)
@classmethod
def get(cls, _id):
@ -110,7 +117,7 @@ class ChannelQuery(Query):
@classmethod
def get_by_platform(cls, platform, platform_channel_id):
result = cls._table.find_one(
result = cls.tablename.find_one(
platform=platform, platform_channel_id=platform_channel_id
)
if not result:
@ -118,7 +125,7 @@ class ChannelQuery(Query):
plugins = ChannelPluginQuery.get_from_channel_id(result["id"])
return cls._obj(plugins={plugin.plugin_id: plugin for plugin in plugins}, **result)
return cls.obj(plugins={plugin.plugin_id: plugin for plugin in plugins}, **result)
@classmethod
def delete(cls, _id):
@ -127,12 +134,12 @@ class ChannelQuery(Query):
class ChannelPluginQuery(Query):
_table = db["channel_plugin"]
_obj = ChannelPlugin
tablename = "channel_plugin"
obj = ChannelPlugin
@classmethod
def create(cls, channel_id, plugin_id, enabled=False, config={}):
if cls.exists(channel_id, plugin_id):
if cls.exists(id=channel_id, plugin_id=plugin_id):
raise cls.Duplicated
params = {
@ -141,25 +148,13 @@ class ChannelPluginQuery(Query):
"enabled": enabled,
"config": config,
}
obj_id = cls._table.insert(params)
return cls._obj(id=obj_id, **params)
@classmethod
def get(cls, channel_id, plugin_id):
result = cls._table.find_one(channel_id=channel_id, plugin_id=plugin_id)
if not result:
raise cls.NotFound
return cls._obj(**result)
obj_id = super().create(**params)
return cls.obj(id=obj_id, **params)
@classmethod
def get_from_channel_id(cls, channel_id):
yield from [cls._obj(**row) for row in cls._table.find(channel_id=channel_id)]
@classmethod
def delete(cls, channel_plugin_id):
return cls._table.delete(id=channel_plugin_id)
yield from [cls.obj(**row) for row in cls.tablename.find(channel_id=channel_id)]
@classmethod
def delete_by_channel(cls, channel_id):
cls._table.delete(channel_id=channel_id)
cls.tablename.delete(channel_id=channel_id)

840
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -27,6 +27,8 @@ flake8 = "^3.7.9"
rope = "^0.16.0"
isort = "^4.3.21"
ipdb = "^0.13.2"
pytest = "^6.1.2"
pytest-cov = "^2.10.1"
[tool.poetry.plugins]
[tool.poetry.plugins."butterrobot.plugins"]

109
tests/test_db.py Normal file
View File

@ -0,0 +1,109 @@
import os.path
import tempfile
from dataclasses import dataclass
from unittest import mock
import dataset
import pytest
from butterrobot import db
@dataclass
class DummyItem:
id: int
foo: str
class DummyQuery(db.Query):
tablename = "dummy"
obj = DummyItem
class MockDatabase:
def __init__(self):
self.temp_dir = tempfile.TemporaryDirectory()
def __enter__(self):
db_path = os.path.join(self.temp_dir.name, "db.sqlite")
db.db = dataset.connect(f"sqlite:///{db_path}")
def __exit__(self, exc_type, exc_val, exc_tb):
self.temp_dir.cleanup()
def test_query_create_ok():
with MockDatabase():
assert DummyQuery.create(foo="bar")
def test_query_delete_ok():
with MockDatabase():
item_id = DummyQuery.create(foo="bar")
assert DummyQuery.delete(item_id)
def test_query_exists_by_id_ok():
with MockDatabase():
assert not DummyQuery.exists(id=1)
item_id = DummyQuery.create(foo="bar")
assert DummyQuery.exists(id=item_id)
def test_query_exists_by_attribute_ok():
with MockDatabase():
assert not DummyQuery.exists(id=1)
item_id = DummyQuery.create(foo="bar")
assert DummyQuery.exists(foo="bar")
def test_query_get_ok():
with MockDatabase():
item_id = DummyQuery.create(foo="bar")
item = DummyQuery.get(id=item_id)
assert item.id
def test_query_all_ok():
with MockDatabase():
assert len(list(DummyQuery.all())) == 0
[DummyQuery.create(foo="bar") for i in range(0, 3)]
assert len(list(DummyQuery.all())) == 3
def test_update_ok():
with MockDatabase():
expected = "bar2"
item_id = DummyQuery.create(foo="bar")
assert DummyQuery.update(item_id, foo=expected)
item = DummyQuery.get(id=item_id)
assert item.foo == expected
def test_create_user_sets_password_ok():
password = "password"
with MockDatabase():
user_id = db.UserQuery.create(username="foo", password=password)
user = db.UserQuery.get(id=user_id)
assert user.password == db.UserQuery._hash_password(password)
def test_user_check_credentials_ok():
with MockDatabase():
username = "foo"
password = "bar"
user_id = db.UserQuery.create(username=username, password=password)
user = db.UserQuery.get(id=user_id)
user = db.UserQuery.check_credentials(username, password)
assert isinstance(user, db.UserQuery.obj)
def test_user_check_credentials_ko():
with MockDatabase():
username = "foo"
password = "bar"
user_id = db.UserQuery.create(username=username, password=password)
user = db.UserQuery.get(id=user_id)
assert not db.UserQuery.check_credentials(username, "error")
assert not db.UserQuery.check_credentials("error", password)
assert not db.UserQuery.check_credentials("error", "error")