From 391e16e91b75ed9ddd47e88e8a404e064518d28a Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Thu, 25 Jun 2020 18:52:34 +0100 Subject: [PATCH 01/11] Separate post from revision in db models --- drp/models/__init__.py | 4 +- drp/models/post.py | 68 ++--- ...separate_posts_into_posts_and_revisions.py | 249 ++++++++++++++++++ 3 files changed, 289 insertions(+), 32 deletions(-) create mode 100644 migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py diff --git a/drp/models/__init__.py b/drp/models/__init__.py index 2cf5cdf..85cc42f 100644 --- a/drp/models/__init__.py +++ b/drp/models/__init__.py @@ -1,8 +1,8 @@ from .question import Question, Site, Subject, Grade -from .post import Post, Tag, File, Post_Tag +from .post import Post, PostRevision, Tag, PostRev_Tag, File from .device import Device from .user import User, UserRole -__all__ = ["Post", "Tag", "File", "Post_Tag", "Question", +__all__ = ["Post", "PostRevision", "Tag", "PostRev_Tag", "File", "Site", "Subject", "Grade", "Device", "User", "UserRole"] diff --git a/drp/models/post.py b/drp/models/post.py index 38552a3..003c2f7 100644 --- a/drp/models/post.py +++ b/drp/models/post.py @@ -1,4 +1,3 @@ -from sqlalchemy.schema import Sequence from sqlalchemy.sql import func from sqlalchemy.sql.expression import cast from sqlalchemy.orm import relationship @@ -20,23 +19,39 @@ def create_tsvector(*components): class Post(db.Model): __tablename__ = "posts" - post_id_seq = Sequence('post_id_seq', metadata=db.Model.metadata) + id = db.Column(db.Integer, primary_key=True) + + is_guideline = db.Column(db.Boolean()) + + latest_rev_id = db.Column(db.Integer, db.ForeignKey( + "post_revisions.id", name="posts_latest_rev_id_fkey", + ondelete="SET NULL")) + + latest_rev = relationship("PostRevision", + back_populates="post", + foreign_keys=[latest_rev_id]) + + resolves = relationship("Question", back_populates="resolved_by") + + +class PostRevision(db.Model): + __tablename__ = "post_revisions" id = db.Column(db.Integer, primary_key=True) - title = db.Column(db.String(120), nullable=False) - summary = db.Column(db.String(200)) + post_id = db.Column(db.Integer, db.ForeignKey( + "posts.id", name="post_revisions_post_id_fkey", ondelete="SET NULL")) + + title = db.Column(db.Text(), nullable=False) + summary = db.Column(db.Text()) content = db.Column(db.Text()) - is_guideline = db.Column(db.Boolean()) - is_current = db.Column(db.Boolean(), default=True) - post_id = db.Column(db.Integer, server_default=post_id_seq.next_value()) created_at = db.Column(db.DateTime(timezone=True), nullable=False, server_default=func.now()) - tags = relationship("Tag", secondary="post_tag") + post = relationship("Post", foreign_keys=[post_id]) - files = relationship("File", back_populates="post") - resolves = relationship("Question", back_populates="resolved_by") + tags = relationship("Tag", secondary="post_rev_tag") + files = relationship("File", back_populates="post_revision") __ts_vector__ = create_tsvector( title, @@ -46,20 +61,12 @@ class Post(db.Model): __table_args__ = ( db.Index( - 'idx_post_fulltextsearch', + 'idx_post_revision_fulltextsearch', __ts_vector__, postgresql_using='gin' ), - db.Index( - 'unique_current_post_id', post_id, is_current, - unique=True, - postgresql_where=is_current - ), ) - def __repr__(self): - return f"" - class Tag(db.Model): __tablename__ = "tags" @@ -67,28 +74,28 @@ class Tag(db.Model): id = db.Column(db.Integer, primary_key=True) name = db.Column(db.Text, unique=True) - posts = relationship("Post", secondary="post_tag") + post_revisions = relationship("PostRevision", secondary="post_rev_tag") def __repr__(self): return f"" -class Post_Tag(db.Model): - __tablename__ = "post_tag" +class PostRev_Tag(db.Model): + __tablename__ = "post_rev_tag" - post_id = db.Column(db.Integer, - db.ForeignKey("posts.id"), - primary_key=True) + revision_id = db.Column(db.Integer, + db.ForeignKey("post_revisions.id"), + primary_key=True) tag_id = db.Column(db.Integer, db.ForeignKey("tags.id"), primary_key=True) - post = relationship(Post) + revision = relationship(PostRevision) tag = relationship(Tag) def __repr__(self): - return f" {self.tag}>" + return f" {self.tag}>" class File(db.Model): @@ -98,9 +105,10 @@ class File(db.Model): name = db.Column(db.String(200)) filename = db.Column(db.String(300)) - post_id = db.Column(db.Integer, - db.ForeignKey("posts.id")) - post = relationship('Post', back_populates='files') + post_rev_id = db.Column(db.Integer, + db.ForeignKey("post_revisions.id")) + + post_revision = relationship('PostRevision', back_populates='files') def __repr__(self): return f"" diff --git a/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py b/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py new file mode 100644 index 0000000..d283cfb --- /dev/null +++ b/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py @@ -0,0 +1,249 @@ +"""Separate posts into posts and revisions + +Revision ID: ef0a345b0640 +Revises: 63246e4d9192 +Create Date: 2020-06-25 15:37:09.352559 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy import orm +from sqlalchemy.ext.declarative import declarative_base + +# revision identifiers, used by Alembic. +revision = 'ef0a345b0640' +down_revision = '63246e4d9192' +branch_labels = None +depends_on = None + +Base = declarative_base() + + +class Post_old(Base): + __tablename__ = "posts" + + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String(120)) + summary = sa.Column(sa.String(200)) + content = sa.Column(sa.Text()) + is_guideline = sa.Column(sa.Boolean()) + is_current = sa.Column(sa.Boolean()) + post_id = sa.Column(sa.Integer) + created_at = sa.Column(sa.DateTime(timezone=True)) + + +class Post_new_1(Base): + __tablename__ = "posts_new" + + id = sa.Column(sa.Integer(), primary_key=True) + is_guideline = sa.Column(sa.Boolean()) + latest_rev_id = sa.Column(sa.Integer()) + + +class PostRevision(Base): + __tablename__ = "post_revisions" + + id = sa.Column(sa.Integer, primary_key=True) + post_id = sa.Column(sa.Integer, sa.ForeignKey("posts_new.id")) + + title = sa.Column(sa.Text(), nullable=False) + summary = sa.Column(sa.Text()) + content = sa.Column(sa.Text()) + + created_at = sa.Column(sa.DateTime(timezone=True), nullable=False) + + +class Post_Tag_old(Base): + __tablename__ = "post_tag" + + post_id = sa.Column(sa.Integer, + sa.ForeignKey("posts.id"), + primary_key=True) + + tag_id = sa.Column(sa.Integer, + sa.ForeignKey("tags.id"), + primary_key=True) + + +class PostRev_Tag(Base): + __tablename__ = "post_rev_tag" + + revision_id = sa.Column(sa.Integer, + sa.ForeignKey("post_revisions.id"), + primary_key=True) + + tag_id = sa.Column(sa.Integer, + sa.ForeignKey("tags.id"), + primary_key=True) + + +class Tag(Base): + __tablename__ = "tags" + + id = sa.Column(sa.Integer, primary_key=True) + + +class File_new(Base): + __tablename__ = "files" + + id = sa.Column(sa.Integer, primary_key=True) + + post_id = sa.Column(sa.Integer, + sa.ForeignKey("posts.id")) + + post_rev_id = sa.Column(sa.Integer, + sa.ForeignKey("post_revisions.id")) + + +class Question(Base): + __tablename__ = "questions" + + id = sa.Column(sa.Integer, primary_key=True) + + post_id = sa.Column(sa.Integer) + post_id_new = sa.Column(sa.Integer, + sa.ForeignKey("posts_new.id")) + + +def upgrade(): + bind = op.get_bind() + session = orm.Session(bind=bind) + + # Create the new posts table + op.create_table("posts_new", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("is_guideline", sa.Boolean(), nullable=True), + sa.Column("latest_rev_id", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id") + ) + + # Create the post revisions table + op.create_table('post_revisions', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('post_id', sa.Integer(), nullable=True), + sa.Column('title', sa.Text(), nullable=False), + sa.Column('summary', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), + server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint( + ['post_id'], ['posts_new.id'], ondelete="SET NULL"), + sa.PrimaryKeyConstraint('id') + ) + + # Create linking table for revisions <-> tags + op.create_table('post_rev_tag', + sa.Column('revision_id', sa.Integer(), nullable=False), + sa.Column('tag_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ['revision_id'], ['post_revisions.id'], ), + sa.ForeignKeyConstraint(['tag_id'], ['tags.id'], ), + sa.PrimaryKeyConstraint('revision_id', 'tag_id') + ) + + op.add_column('files', sa.Column( + 'post_rev_id', sa.Integer(), nullable=True)) + op.drop_constraint('files_post_id_fkey', 'files', type_='foreignkey') + + op.add_column("questions", sa.Column( + "post_id_new", sa.Integer(), nullable=True)) + op.drop_constraint("questions_post_id_fkey", + "questions", type_="foreignkey") + + # Copy existing posts to the new tables + for post in session.query(Post_old).filter(Post_old.is_current).all(): + # Create post entry for the latest version of this post + new_post = Post_new_1(is_guideline=post.is_guideline) + session.add(new_post) + session.commit() + # Create revision entries for each version of this post + for p in session.query(Post_old).filter( + Post_old.post_id == post.post_id).all(): + revision = PostRevision(post_id=new_post.id, title=p.title, + summary=p.summary, content=p.content, + created_at=p.created_at) + session.add(revision) + session.commit() + # Copy over all the tags + for post_tag in session.query(Post_Tag_old).filter( + Post_Tag_old.post_id == p.id): + rev_tag = PostRev_Tag(revision_id=revision.id, + tag_id=post_tag.tag_id) + session.add(rev_tag) + session.commit() + # Update references from files + for file in session.query(File_new).filter( + File_new.post_id == p.id): + file.post_rev_id = revision.id + session.commit() + # Link latest version back to the owning post + if p.is_current: + new_post.latest_rev_id = revision.id + session.commit() + # Update references from questions + for q in session.query(Question).filter(Question.post_id == post.id): + q.post_id_new = new_post.id + session.commit() + + op.drop_table("post_tag") + op.drop_column("files", "post_id") + op.drop_column("questions", "post_id") + op.alter_column("questions", "post_id_new", new_column_name="post_id") + + op.drop_table("posts") + op.rename_table("posts_new", "posts") + + op.create_foreign_key(None, 'posts', 'post_revisions', + ['latest_rev_id'], ['id'], ondelete="SET NULL") + + op.create_foreign_key(None, 'files', + 'post_revisions', ['post_rev_id'], ['id']) + + op.create_foreign_key("post_revisions_post_id_fkey", 'questions', + 'posts', ['post_id'], ['id']) + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('posts', sa.Column('created_at', postgresql.TIMESTAMP( + timezone=True), server_default=sa.text('now()'), autoincrement=False, + nullable=False)) + op.add_column('posts', sa.Column('title', sa.VARCHAR( + length=120), autoincrement=False, nullable=False)) + op.add_column('posts', sa.Column('summary', sa.VARCHAR( + length=200), autoincrement=False, nullable=True)) + op.add_column('posts', sa.Column('is_current', sa.BOOLEAN(), + autoincrement=False, nullable=True)) + op.add_column('posts', sa.Column('content', sa.TEXT(), + autoincrement=False, nullable=True)) + op.add_column('posts', sa.Column('post_id', sa.INTEGER(), + server_default=sa.text( + "nextval('post_id_seq'::regclass)"), autoincrement=True, + nullable=True)) + op.drop_constraint('posts_latest_rev_id_fkey', 'posts', type_='foreignkey') + op.create_index('unique_current_post_id', 'posts', [ + 'post_id', 'is_current'], unique=True) + op.drop_column('posts', 'latest_rev_id') + op.add_column('files', sa.Column('post_id', sa.INTEGER(), + autoincrement=False, nullable=True)) + op.drop_constraint(None, 'files', type_='foreignkey') + op.create_foreign_key('files_post_id_fkey', 'files', + 'posts', ['post_id'], ['id']) + op.drop_column('files', 'post_rev_id') + op.create_table('post_tag', + sa.Column('post_id', sa.INTEGER(), + autoincrement=False, nullable=False), + sa.Column('tag_id', sa.INTEGER(), + autoincrement=False, nullable=False), + sa.ForeignKeyConstraint( + ['post_id'], ['posts.id'], + name='post_tag_post_id_fkey'), + sa.ForeignKeyConstraint(['tag_id'], ['tags.id'], + name='post_tag_tag_id_fkey'), + sa.PrimaryKeyConstraint( + 'post_id', 'tag_id', name='post_tag_pkey') + ) + op.drop_table('post_rev_tag') + op.drop_table('post_revisions') + # ### end Alembic commands ### From c579be2ea97e642df2898944d7d02150f8501355 Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Thu, 25 Jun 2020 19:24:04 +0100 Subject: [PATCH 02/11] Add delete cascade for posts, make fields non-nullable --- drp/models/__init__.py | 2 +- drp/models/post.py | 9 ++-- ...separate_posts_into_posts_and_revisions.py | 51 ++----------------- 3 files changed, 10 insertions(+), 52 deletions(-) diff --git a/drp/models/__init__.py b/drp/models/__init__.py index 85cc42f..bc610b2 100644 --- a/drp/models/__init__.py +++ b/drp/models/__init__.py @@ -4,5 +4,5 @@ from .user import User, UserRole __all__ = ["Post", "PostRevision", "Tag", "PostRev_Tag", "File", - "Site", "Subject", "Grade", "Device", + "Question", "Site", "Subject", "Grade", "Device", "User", "UserRole"] diff --git a/drp/models/post.py b/drp/models/post.py index 003c2f7..823a488 100644 --- a/drp/models/post.py +++ b/drp/models/post.py @@ -24,8 +24,7 @@ class Post(db.Model): is_guideline = db.Column(db.Boolean()) latest_rev_id = db.Column(db.Integer, db.ForeignKey( - "post_revisions.id", name="posts_latest_rev_id_fkey", - ondelete="SET NULL")) + "post_revisions.id", name="posts_latest_rev_id_fkey")) latest_rev = relationship("PostRevision", back_populates="post", @@ -39,11 +38,11 @@ class PostRevision(db.Model): id = db.Column(db.Integer, primary_key=True) post_id = db.Column(db.Integer, db.ForeignKey( - "posts.id", name="post_revisions_post_id_fkey", ondelete="SET NULL")) + "posts.id", name="post_revisions_post_id_fkey", ondelete="CASCADE")) title = db.Column(db.Text(), nullable=False) - summary = db.Column(db.Text()) - content = db.Column(db.Text()) + summary = db.Column(db.Text(), nullable=False) + content = db.Column(db.Text(), nullable=False) created_at = db.Column(db.DateTime(timezone=True), nullable=False, server_default=func.now()) diff --git a/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py b/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py index d283cfb..10394ed 100644 --- a/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py +++ b/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py @@ -123,12 +123,12 @@ def upgrade(): sa.Column('id', sa.Integer(), nullable=False), sa.Column('post_id', sa.Integer(), nullable=True), sa.Column('title', sa.Text(), nullable=False), - sa.Column('summary', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), + sa.Column('summary', sa.Text(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), sa.ForeignKeyConstraint( - ['post_id'], ['posts_new.id'], ondelete="SET NULL"), + ['post_id'], ['posts_new.id'], ondelete="CASCADE"), sa.PrimaryKeyConstraint('id') ) @@ -195,7 +195,7 @@ def upgrade(): op.rename_table("posts_new", "posts") op.create_foreign_key(None, 'posts', 'post_revisions', - ['latest_rev_id'], ['id'], ondelete="SET NULL") + ['latest_rev_id'], ['id']) op.create_foreign_key(None, 'files', 'post_revisions', ['post_rev_id'], ['id']) @@ -205,45 +205,4 @@ def upgrade(): def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.add_column('posts', sa.Column('created_at', postgresql.TIMESTAMP( - timezone=True), server_default=sa.text('now()'), autoincrement=False, - nullable=False)) - op.add_column('posts', sa.Column('title', sa.VARCHAR( - length=120), autoincrement=False, nullable=False)) - op.add_column('posts', sa.Column('summary', sa.VARCHAR( - length=200), autoincrement=False, nullable=True)) - op.add_column('posts', sa.Column('is_current', sa.BOOLEAN(), - autoincrement=False, nullable=True)) - op.add_column('posts', sa.Column('content', sa.TEXT(), - autoincrement=False, nullable=True)) - op.add_column('posts', sa.Column('post_id', sa.INTEGER(), - server_default=sa.text( - "nextval('post_id_seq'::regclass)"), autoincrement=True, - nullable=True)) - op.drop_constraint('posts_latest_rev_id_fkey', 'posts', type_='foreignkey') - op.create_index('unique_current_post_id', 'posts', [ - 'post_id', 'is_current'], unique=True) - op.drop_column('posts', 'latest_rev_id') - op.add_column('files', sa.Column('post_id', sa.INTEGER(), - autoincrement=False, nullable=True)) - op.drop_constraint(None, 'files', type_='foreignkey') - op.create_foreign_key('files_post_id_fkey', 'files', - 'posts', ['post_id'], ['id']) - op.drop_column('files', 'post_rev_id') - op.create_table('post_tag', - sa.Column('post_id', sa.INTEGER(), - autoincrement=False, nullable=False), - sa.Column('tag_id', sa.INTEGER(), - autoincrement=False, nullable=False), - sa.ForeignKeyConstraint( - ['post_id'], ['posts.id'], - name='post_tag_post_id_fkey'), - sa.ForeignKeyConstraint(['tag_id'], ['tags.id'], - name='post_tag_tag_id_fkey'), - sa.PrimaryKeyConstraint( - 'post_id', 'tag_id', name='post_tag_pkey') - ) - op.drop_table('post_rev_tag') - op.drop_table('post_revisions') - # ### end Alembic commands ### + raise Exception("Not implemented :(") From f3303e7922ea765ace7c0aa8c0ad429a09562af4 Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Fri, 26 Jun 2020 16:24:15 +0100 Subject: [PATCH 03/11] Migrate posts api to reflect new database structure --- drp/__init__.py | 8 +- drp/api/__init__.py | 7 +- drp/api/auth.py | 24 +- drp/api/files.py | 2 +- drp/api/notifications.py | 6 +- drp/api/posts.py | 901 ++++++++---------- drp/api/search.py | 62 +- drp/api/users.py | 26 +- drp/api/utils.py | 25 +- drp/models/post.py | 19 +- ...separate_posts_into_posts_and_revisions.py | 7 +- 11 files changed, 492 insertions(+), 595 deletions(-) diff --git a/drp/__init__.py b/drp/__init__.py index 3c7805b..28971b2 100644 --- a/drp/__init__.py +++ b/drp/__init__.py @@ -18,13 +18,6 @@ def init_cli(app): def init_api(app): api = Api(app) - api.add_resource(res.PostResource, "/api/posts/") - api.add_resource(res.PostListResource, "/api/posts") - - api.add_resource(res.RevisionResource, "/api/revisions/") - - api.add_resource(res.PostFetchResource, "/api/fetch/posts/") - api.add_resource(res.PostSearchResource, "/api/search/posts/") @@ -47,6 +40,7 @@ def init_api(app): api.add_resource(res.SiteResource, "/api/sites/") api.add_resource(res.SiteListResource, "/api/sites") + app.register_blueprint(res.posts, url_prefix="/api/posts") app.register_blueprint(res.questions, url_prefix="/api/questions") app.register_blueprint(res.notifications, url_prefix="/api/notifications") app.register_blueprint(res.users, url_prefix="/api/users") diff --git a/drp/api/__init__.py b/drp/api/__init__.py index 03b45c1..2688478 100644 --- a/drp/api/__init__.py +++ b/drp/api/__init__.py @@ -1,6 +1,5 @@ from .auth import auth -from .posts import (PostResource, PostListResource, RevisionResource, - PostFetchResource) +from .posts import posts from .search import PostSearchResource from .tags import TagListResource, TagResource from .files import (FileResource, FileListResource, RawFileViewResource, @@ -11,9 +10,7 @@ from .notifications import notifications from .users import users -__all__ = ["PostResource", "PostListResource", - "RevisionResource", "PostFetchResource", - "PostSearchResource", +__all__ = ["posts", "PostSearchResource", "TagResource", "TagListResource", "QuestionResource", "QuestionListResource", "FileResource", "FileListResource", diff --git a/drp/api/auth.py b/drp/api/auth.py index 9722cec..c2b9f32 100644 --- a/drp/api/auth.py +++ b/drp/api/auth.py @@ -16,7 +16,7 @@ from ..swag import swag from .users import serialize_role -from .utils import error +from .utils import abort auth = Blueprint("auth", __name__) @@ -75,21 +75,19 @@ def authenticate(): password = body.get("password") if email is None or password is None: - return error(400, - message="`email` and `password` fields" - "are required.") + abort(400, message="`email` and `password` fields are required.") user = User.query.filter(User.email == email).one_or_none() if user is None: - return error(401, type="InvalidCredentials") + abort(401, type="InvalidCredentials") hasher = PasswordHasher() try: hasher.verify(user.password_hash, password) except VerifyMismatchError: - return error(401, type="InvalidCredentials") + abort(401, type="InvalidCredentials") if hasher.check_needs_rehash(user.password_hash): hash = hasher.hash(password) @@ -97,7 +95,7 @@ def authenticate(): db.session.commit() if not user.confirmed: - return error(401, type="Unconfirmed") + abort(401, type="Unconfirmed") now = time.time() expiration_time = now + 2 * 60 * 60 @@ -125,27 +123,27 @@ def register(): email = body.get("email") if not email: - return error(400, message="`email` is required") + abort(400, message="`email` is required") password = body.get("password") if not password: - return error(400, message="`password` is required") + abort(400, message="`password` is required") email = email.lower() parts = email.split("@") if len(parts) != 2: - return error(400, type="InvalidEmail") + abort(400, type="InvalidEmail") if len(password) < 8: - return error(400, type="ShortPassword") + abort(400, type="ShortPassword") domain = parts[1] if domain != "nhs.net" and domain != "ic.ac.uk" \ and domain != "imperial.ac.uk": - return error(400, type="UnauthorisedDomain") + abort(400, type="UnauthorisedDomain") if User.query.filter(User.email == email).one_or_none() is not None: - return error(400, type="Registered") + abort(400, type="Registered") hasher = PasswordHasher() hash = hasher.hash(password) diff --git a/drp/api/files.py b/drp/api/files.py index 5ce1024..1e86909 100644 --- a/drp/api/files.py +++ b/drp/api/files.py @@ -30,7 +30,7 @@ def serialize_file(file): return { "id": file.id, "name": file.name, - "post": file.post_id + "post_revision": file.post_rev_id } diff --git a/drp/api/notifications.py b/drp/api/notifications.py index 50ddedc..871b63e 100644 --- a/drp/api/notifications.py +++ b/drp/api/notifications.py @@ -4,7 +4,7 @@ from ..db import db from ..models import Device, User -from .utils import error +from .utils import abort notifications = Blueprint("notifications", __name__) @@ -15,11 +15,11 @@ def register(): user = request.args.get("user") if not token: - return error(400, "Missing `token` query parameter") + abort(400, "Missing `token` query parameter") user = User.query.filter(User.id == user).one_or_none() if not user: - return error(400, "Missing `user` query parameter") + abort(400, "Missing `user` query parameter") device = Device(expo_push_token=token, user=user) diff --git a/drp/api/posts.py b/drp/api/posts.py index 2491989..a49295f 100644 --- a/drp/api/posts.py +++ b/drp/api/posts.py @@ -1,528 +1,427 @@ import pytz - import os import werkzeug import secrets from datetime import datetime -from flask import request, current_app -from flask_restful import Resource, abort - -from ..db import db -from ..models import Post, Post_Tag, Tag, File, Question -from ..swag import swag +from flask import Blueprint, request, jsonify, current_app +from sqlalchemy.orm import joinedload from .. import notifications -from .tags import serialize_tag -from .files import serialize_file, allowed_file - - -def delete_post(post): - for file in post.files: - try: - os.remove(os.path.join( - current_app.config['UPLOAD_FOLDER'], file.filename)) - except OSError as e: - print("Could not delete file, " + repr(e)) +from ..models import Post, PostRevision, Tag, PostRev_Tag, File, Question +from ..db import db - db.session.delete(file) +from .utils import abort +from .files import allowed_file - db.session.delete(post) +posts = Blueprint("posts", __name__) -def migrate_resolved_questions(questions, revision): - for question in questions: - question.resolved_by = revision +def serialize_tag(tag): + return { + "id": tag.id, + "name": tag.name + } -def unresolve_all(questions): - for question in questions: - question.resolved = False +def serialize_file(f): + return { + "id": f.id, + "name": f.name, + "filename": f.filename + } -def get_current_post_by_id(id): - return Post.query.filter(Post.is_current & (Post.post_id == id)) \ - .one_or_none() +def serialize_revision(rev): + return { + "id": rev.id, + "title": rev.title, + "summary": rev.summary, + "content": rev.content, + "created_at": rev.created_at.astimezone(pytz.utc).isoformat(), + "tags": [serialize_tag(tag) for tag in rev.tags], + "files": [serialize_file(f) for f in rev.files] + } -@swag.definition("Post") def serialize_post(post): - """ - Represents a post revision. - --- - properties: - id: - type: integer - title: - type: string - summary: - type: string - content: - type: string - created_at: - type: string - tags: - type: array - items: - $ref: "#/definitions/Tag" - files: - type: array - items: - $ref: "#/definitions/File" - is_guideline: - type: boolean - is_current: - type: boolean - revision_id: - type: integer - """ return { - "id": post.post_id, - "title": post.title, - "summary": post.summary, - "content": post.content, - "is_guideline": post.is_guideline, - "is_current": post.is_current, - "revision_id": post.id, - "created_at": post.created_at.astimezone(pytz.utc).isoformat(), - "tags": [serialize_tag(tag) for tag in post.tags], - "files": [serialize_file(file) for file in post.files] + "id": post.id, + "type": "guideline" if post.is_guideline else "update", + "latest_revision": serialize_revision(post.latest_rev) } -class PostResource(Resource): - - def get(self, id): - """ - Gets a single post by id. - --- - parameters: - - name: id - in: path - type: integer - required: true - - name: include_old - in: query - type: boolean - required: false - - name: reverse - in: query - type: boolean - required: false - responses: - 200: - schema: - type: array - items: - $ref: "#/definitions/Post" - 404: - description: Not found - """ - include_old = request.args.get("include_old") - reverse = request.args.get("reverse") - - query = Post.query.filter(Post.post_id == id) - - if include_old != "true": - query = query.filter(Post.is_current) - - if reverse == "true": - query = query.order_by(Post.id.desc()) - else: - query = query.order_by(Post.id) - - posts = query.all() - - if len(posts) == 0: - return abort(404) - - return [serialize_post(post) for post in posts] - - def delete(self, id): - """ - Deletes all revisions of a post by ID. - --- - parameters: - - name: id - in: path - type: integer - required: true - responses: - 204: - description: Success - 404: - description: Not found - """ - revisions = Post.query.filter(Post.post_id == id).all() - - if len(revisions) == 0: - return abort(404) - - for revision in revisions: - delete_post(revision) - # Mark associated questions as unresolved - unresolve_all(revision.resolves) - - db.session.commit() - - return "", 204 - - -class PostListResource(Resource): - - def get(self): - """ - Gets a list of all posts. - --- - parameters: - - name: guidelines_only - in: query - type: boolean - required: false - - name: include_old - in: query - type: boolean - required: false - - name: tag - in: query - type: string - required: false - - name: page - in: query - type: integer - - name: per_page - in: query - type: integer - responses: - 200: - schema: - type: array - items: - $ref: "#/definitions/Post" - - """ - guidelines_only = request.args.get("guidelines_only") - include_old = request.args.get("include_old") - tag = request.args.get("tag") - - page = request.args.get("page") - if page is None: +@posts.route("/", methods=["GET", "POST"]) +def all_posts(): + if request.method == "GET": + return get_posts() + + elif request.method == "POST": + return create_post() + + +def get_posts(): + ids = request.args.getlist("ids") + type = request.args.get("type") + tag = request.args.get("tag") + page = request.args.get("page") + per_page = request.args.get("per_page") + + query = Post.query.join(Post.latest_rev).options( + joinedload("latest_rev").options( + joinedload("tags"), + joinedload("files"))) + + # Filter by id + if len(ids) == 1 and ',' in ids[0]: + ids = ids[0].split(',') + + if len(ids) == 1 and ids[0] == "": + return jsonify([]) + + if not all(id.isdigit() for id in ids): + abort(400, message="IDs must be integers") + + if len(ids) > 0: + query = query.filter(Post.id.in_(ids)) + + # Filter by type + if type == "update": + query = query.filter(Post.is_guideline == False) # noqa + + elif type == "guideline": + query = query.filter(Post.is_guideline == True) # noqa + + elif type and type != "any": + abort(400, "type must be one of `any`, `update` or `guideline`") + + # Filter by tag name + if tag: + tag = Tag.query.filter(Tag.name == tag).one_or_none() + if not tag: + return jsonify([]) + query = query.join(PostRev_Tag).join( + Tag).filter(Tag.id == tag.id) + + query = query.join(Post.latest_rev).order_by( + PostRevision.created_at.desc()) + + # Pagination + if per_page: + if not page: page = 0 else: page = int(page) - per_page = request.args.get("per_page") - if per_page is not None: - per_page = int(per_page) - - query = Post.query - - if guidelines_only == "true": - query = query.filter(Post.is_guideline) - if include_old != "true": - query = query.filter(Post.is_current) - if tag is not None: - query = query.join(Post_Tag).join(Tag).filter(Tag.name == tag) - - query = query.order_by(Post.created_at.desc()) - - if per_page is not None: - query = query.limit(per_page).offset(page * per_page) - - return [serialize_post(post) for post in query] - - def post(self): - """ - Creates a new post. - --- - parameters: - - in: formData - name: title - type: string - required: true - maxLength: 120 - description: The title of the post. - - in: formData - name: summary - type: string - required: true - maxLength: 200 - description: A short summary of the post. - - in: formData - name: content - type: string - required: true - - in: formData - name: tags - required: false - type: array - description: The tags associated with the post. - items: - type: string - - in: formData - name: files - type: array - description: The files attached to the post. - items: - type: file - - in: formData - type: array - name: names - description: The names of files attached to the post. - items: - type: string - - in: formData - type: boolean - name: is_guideline - description: Indicates whether a post is a guideline. - - in: formData - type: integer - name: updates - description: ID of the post that is to be updated. - - in: formData - type: array - name: resolves - description: The IDs of questions that are resolved by this post. - items: - type: number - responses: - 200: - schema: - $ref: "#/definitions/Post" - """ - title = request.form.get('title') - summary = request.form.get('summary') - content = request.form.get('content') - tag_names = request.form.getlist('tags') - files = request.files.getlist('files') - names = request.form.getlist('names') - is_guideline = request.form.get('is_guideline') - updates = request.form.get('updates') - resolves = request.form.getlist('resolves') - - # Check that required fields are present - if title is None or summary is None or content is None: - return abort(400, - message="`title`, `summary` and `content` \ - fields are required.") - - if title == "": - return abort(400, message="`title` field cannot be empty.") - - # Check that no field exceeds permitted length - def error_message(name, count): - return f"`{name}` must not be more than {count} characters." - - if len(title) > 120: - return abort(400, message=error_message("title", 120)) - - if summary is not None and len(summary) > 200: - return abort(400, message=error_message("summary", 200)) - - # Check that tags are valid - tags = None - - if len(tag_names) != 0: - tags = Tag.query.filter(Tag.name.in_(tag_names)) - if tags.count() < len(tag_names): - return abort(400, message="Invalid tags - all tags must be" - " predefined through the tags api.") - - tags = tags.all() if tags is not None else [] - - # Check that files and the associated names are valid - if len(files) != len(names): - return abort(400, message="The number of files must match " - "the number of supplied names.") - - for name in names: - if len(name) > 200: - return abort(400, message=error_message("file name", 200)) - - if not allowed_file(name, - current_app.config["ALLOWED_FILE_EXTENSIONS"]): - return abort(400, message=f"The file extension of {name} is " - "not allowed for security reasons. If " - "you believe that this file type is safe " - "to upload, contact the developer.") - - # Check that the fields for posting guidelines are valid - if is_guideline is not None and is_guideline != "false" \ - and is_guideline != "true": - return abort(400, message="The value is_guideline=" - f"{is_guideline} is invalid.") - - # Check that all resolved questions exist and are not resolved already - resolved_questions = [] - if resolves is not None: - if len(resolves) == 1 and ',' in resolves[0]: - resolves = resolves[0].split(',') - for question_id in resolves: - question = Question.query.filter( - Question.id == question_id).one_or_none() - if question is None: - abort(400, message="One of the resolved questions does " - "not exist") - if question.resolved: - abort( - 400, message="Cannot resolve question that is already " - "resolved.") - resolved_questions.append(question) - - # Add post to the database - if is_guideline == "true" and updates is not None: - old_post = get_current_post_by_id(updates) - if old_post is None or not old_post.is_guideline: - return abort(400, message="Invalid updated post ID.") - post = Post(title=title, summary=summary, content=content, - is_guideline=True, post_id=updates, tags=tags) - old_post.is_current = False - migrate_resolved_questions(old_post.resolves, post) - else: - post = Post(title=title, summary=summary, content=content, - is_guideline=(is_guideline == "true"), tags=tags) - - # Link resolved questions to the post - if len(resolved_questions) > 0: - post.resolves = resolved_questions - for question in resolved_questions: - question.resolved = True - db.session.add(post) - - # Save files - for i in range(0, len(files)): - # Prefix file name with current time and random number to allow - # files with the same name - filename = datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f_') + \ - str(secrets.randbelow(10000000000)) + "_" + \ - werkzeug.utils.secure_filename(names[i]) - - path = os.path.join( - current_app.config['UPLOAD_FOLDER'], filename) - if (os.path.isfile(path)): - return abort(422, message="An unexpected file collision " - "ocurred. This error was thought to be " - "impossible to arise in practice. Please " - "contact the developer quoting this error " - "message.") - files[i].save(path) - - file = File(name=names[i], filename=filename, post=post) - - db.session.add(file) - - db.session.commit() - - if len(resolved_questions) > 0: - for q in resolved_questions: - notifications.send_user( - q.user, "Your question has been resolved", - q.text, data={"id": post.id, "resolves": q.id}) - - notifications.broadcast(title, summary, data={"id": post.post_id}) - - return serialize_post(post) - - -class RevisionResource(Resource): - - def get(self, id): - """ - Gets a single revision by ID. - --- - parameters: - - name: id - in: path - type: integer - required: true - responses: - 200: - schema: - $ref: "#/definitions/Post" - 404: - description: Not found - """ - revision = Post.query.filter(Post.id == id).one_or_none() - - if revision is None: - return abort(404) - - return serialize_post(revision) - - def delete(self, id): - """ - Deletes a revision of a post by ID. - --- - parameters: - - name: id - in: path - type: integer - required: true - responses: - 204: - description: Success - 404: - description: Not found - """ - revision = Post.query.filter(Post.id == id).one_or_none() - - if revision is None: - return abort(404) - - delete_post(revision) - - if revision.is_current: - revisions = Post.query.filter( - Post.post_id == revision.post_id) \ - .order_by(Post.id.desc()).all() - if len(revisions) > 0: - newest = revisions[0] - newest.is_current = True - - # Make resolved questions point to the new current revison - migrate_resolved_questions(revision.resolves, newest) - else: - unresolve_all(revision.resolves) - - db.session.commit() - - return "", 204 - - -class PostFetchResource(Resource): - - def get(self): - """ - Returns a list of posts identified by the supplied IDs. - --- - parameters: - - name: ids - in: query - type: array - items: - type: number - required: true - responses: - 200: - schema: - type: array - items: - $ref: "#/definitions/Post" - 404: - description: Not found - """ - ids = request.args.getlist("ids") - - if len(ids) == 1 and ',' in ids[0]: - ids = ids[0].split(',') - - if len(ids) == 1 and ids[0] == "": - return [] - - if not all(id.isdigit() for id in ids): - abort(400, message="IDs must be integers") - - posts = Post.query.filter( - Post.is_current & Post.post_id.in_(ids)).all() - - return [serialize_post(post) for post in posts] + per_page = int(per_page) + query = query.limit(per_page).offset(page * per_page) + + return jsonify([serialize_post(post) for post in query]) + + +def create_post(): + # return serialize_post(post) + title = request.form.get("title") + summary = request.form.get("summary") + content = request.form.get("content") + tag_names = request.form.getlist("tags") + type = request.form.get("type") + resolves = request.form.getlist("resolves") + files = request.form.getlist("files") + names = request.form.getlist("file_names") + + validate_rev_data(title, summary, content, files, names) + + tags = validate_rev_tags(tag_names) + resolved_questions = validate_rev_resolved_questions(resolves) + + # # Check that type is valid + if type is not None and type != "update" and type != "guideline": + abort(400, message="type must be one of `update` \ + or `guideline`") + + post = Post(is_guideline=type == "guideline") + db.session.add(post) + db.session.commit() + + revision = create_rev(title, summary, content, post, tags) + + link_questions_to_post(resolved_questions, post) + save_files_for_revision(files, names, revision) + + db.session.commit() + + notify_resolved_questions(resolved_questions, post) + notify_new_post_rev(revision) + + return jsonify(serialize_post(post)) + + +@posts.route("/", methods=["GET", "DELETE"]) +def single_post(id): + if request.method == "GET": + return get_post(id) + + elif request.method == "DELETE": + return delete_post(id) + + +def get_post(id): + post = Post.query.filter(Post.id == id).one_or_none() + + if post is None: + abort(404) + + return jsonify(serialize_post(post)) + + +def delete_post(id): + for question in Question.query.filter(Question.post_id == id): + question.resolved_by = None + question.resolved = False + + rows_deleted = Post.query.filter(Post.id == id).delete() + + if rows_deleted == 0: + abort(404) + + db.session.commit() + + return jsonify({"message": "deleted"}), 204 + + +@posts.route("//revisions", methods=["GET", "POST"]) +def post_revisions(id): + post = Post.query.filter(Post.id == id).one_or_none() + + if post is None: + abort(404) + + if request.method == "GET": + return get_post_revisions(post) + + elif request.method == "POST": + return create_post_revision(post) + + +def get_post_revisions(post): + order = request.args.get("order") + query = PostRevision.query.filter(PostRevision.post_id == post.id) + + if not order or order == "desc": + query = query.order_by(PostRevision.created_at.desc()) + + elif order == "asc": + query = query.order_by(PostRevision.created_at.asc()) + + else: + abort(400, message="order must be one of `asc` or `desc`") + + return jsonify([serialize_revision(rev) for rev in query]) + + +def create_post_revision(post): + title = request.form.get("title") + summary = request.form.get("summary") + content = request.form.get("content") + tag_names = request.form.getlist("tags") + resolves = request.form.getlist("resolves") + files = request.form.getlist("files") + names = request.form.getlist("file_names") + + validate_rev_data(title, summary, content, files, names) + + tags = validate_rev_tags(tag_names) + resolved_questions = validate_rev_resolved_questions(resolves) + + revision = create_rev(title, summary, content, post, tags) + + link_questions_to_post(resolved_questions, post) + save_files_for_revision(files, names, revision) + + db.session.commit() + + notify_resolved_questions(resolved_questions, post) + notify_new_post_rev(revision) + + return jsonify(serialize_revision(revision)) + + +@posts.route("//revisions/", + methods=["GET", "DELETE"]) +def post_revision(post_id, revision_id): + post = Post.query.filter(Post.id == post_id).one_or_none() + + if post is None: + abort(404) + + revision = PostRevision.query.filter( + PostRevision.id == revision_id).one_or_none() + + if revision is None: + abort(404) + + if request.method == "GET": + return get_post_revision(post, revision) + + elif request.method == "DELETE": + return delete_post_revision(post, revision) + + +def get_post_revision(post, revision): + return jsonify(serialize_revision(revision)) + + +def delete_post_revision(post, revision): + # If this is the only revision, delete the whole post + if len(post.revisions) == 1: + # Unlink any questions that were resolved by the post + for question in Question.query.filter(Question.post_id == id): + question.resolved_by = None + question.resolved = False + + Post.query.filter(Post.id == post.id).delete() + + else: + # If this is currently the latest revision, + # set the post's latest revision to the previous + # revision + if post.latest_rev.id == revision.id: + post.latest_rev = PostRevision.query \ + .filter(PostRevision.post_id == post.id) \ + .order_by(PostRevision.created_at.desc()) \ + .first() + + # Delete the revision + PostRevision.query \ + .filter(PostRevision.id == revision.id) \ + .delete() + + db.session.commit() + + +def validate_rev_data(title, summary, content, files, names): + # Check that required fields are present + if title is None or summary is None or content is None: + abort(400, message="`title`, `summary` and `content` \ + fields are required.") + + if not title: + abort(400, message="`title` field cannot be empty.") + + # Check that files and the associated names are valid + if len(files) != len(names): + abort(400, message="The number of files must match " + "the number of supplied names.") + + for name in names: + if len(name) > 200: + abort(400, message="file name must not be more than 200 \ + characters") + + if not allowed_file(name, + current_app.config["ALLOWED_FILE_EXTENSIONS"]): + abort(400, message=f"The file extension of {name} is " + "not allowed for security reasons. If " + "you believe that this file type is safe " + "to upload, contact the developer.") + + +def validate_rev_tags(tag_names): + if len(tag_names) > 0: + tags = Tag.query.filter(Tag.name.in_(tag_names)) + + if tags.count() < len(tag_names): + abort(400, message="Invalid tags - all tags must be \ + predefined through the tags api.") + + tags = tags.all() + + else: + tags = [] + + return tags + + +def validate_rev_resolved_questions(resolves): + resolved_questions = [] + + # Check that all resolved questions exist and are not resolved already + if resolves is not None: + if len(resolves) == 1 and ',' in resolves[0]: + resolves = resolves[0].split(',') + + for question_id in resolves: + question = Question.query.filter( + Question.id == question_id).one_or_none() + + if question is None: + abort(400, + message=f"Question with id `{question_id}` does \ + not exist.") + + if question.resolved: + abort(400, + message=f"Question with id `{question_id}` has \ + already been resolved.") + + resolved_questions.append(question) + + return resolved_questions + + +def create_rev(title, summary, content, post, tags): + revision = PostRevision(title=title, summary=summary, + content=content, post=post, tags=tags) + + db.session.add(revision) + db.session.commit() + + post.latest_rev = revision + db.session.commit() + + return revision + + +def link_questions_to_post(questions, post): + for question in questions: + question.resolved = True + question.resolved_by = post + + +def save_files_for_revision(files, names, revision): + for i in range(0, len(files)): + # Prefix file name with current time and random number to allow + # files with the same name + filename = datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f_') + \ + str(secrets.randbelow(10000000000)) + "_" + \ + werkzeug.utils.secure_filename(names[i]) + + path = os.path.join( + current_app.config['UPLOAD_FOLDER'], filename) + + if (os.path.isfile(path)): + abort(422, message="An unexpected file collision " + "ocurred. This error was thought to be " + "impossible to arise in practice. Please " + "contact the developer quoting this error " + "message.") + + files[i].save(path) + + file = File(name=names[i], filename=filename, post_revision=revision) + + db.session.add(file) + + +def notify_resolved_questions(resolved_questions, post): + for q in resolved_questions: + notifications.send_user( + q.user, "Your question has been resolved", + q.text, data={"id": post.id, "resolves": q.id}) + + +def notify_new_post_rev(revision): + notifications.broadcast(revision.title, revision.summary, + data={"id": revision.post.id}) diff --git a/drp/api/search.py b/drp/api/search.py index ff8506a..e1a91f4 100644 --- a/drp/api/search.py +++ b/drp/api/search.py @@ -8,7 +8,7 @@ from .posts import serialize_post -from ..models import Post, Post_Tag, Tag +# from ..models import Post, Post_Tag, Tag from ..db import db @@ -87,35 +87,35 @@ def get(self, searched): items: $ref: "#/definitions/Post" """ - if searched == "": - return abort(400, message="Empty string search is invalid.") - - page = request.args.get("page") - results_per_page = request.args.get("results_per_page") - guidelines_only = request.args.get("guidelines_only") - include_old = request.args.get("include_old") - tag = request.args.get("tag") - - ts_query, ts_rank = construct_fulltext_query_and_rank(searched) - - # Query for the search results ordered by rank - query = db.session.query(Post, ts_rank) \ - .filter(Post.__ts_vector__.op('@@')(ts_query)) - if (include_old != "true"): - query = query.filter(Post.is_current) - if guidelines_only == "true": - query = query.filter(Post.is_guideline) - if tag is not None: - query = query.join(Post_Tag).join(Tag).filter(Tag.name == tag) - query = query.order_by(text("rank desc"), Post.created_at.desc()) - - if page is None or results_per_page is None: - return extract_results_posts(query) - - if not page.isdigit() or not results_per_page.isdigit(): - return abort(400, message="Page and results_per_page fields must " - "be numbers.") - - query = limit_query(query, page, results_per_page) + # if searched == "": + # return abort(400, message="Empty string search is invalid.") + + # page = request.args.get("page") + # results_per_page = request.args.get("results_per_page") + # guidelines_only = request.args.get("guidelines_only") + # include_old = request.args.get("include_old") + # tag = request.args.get("tag") + + # ts_query, ts_rank = construct_fulltext_query_and_rank(searched) + + # # Query for the search results ordered by rank + # query = db.session.query(Post, ts_rank) \ + # .filter(Post.__ts_vector__.op('@@')(ts_query)) + # if (include_old != "true"): + # query = query.filter(Post.is_current) + # if guidelines_only == "true": + # query = query.filter(Post.is_guideline) + # if tag is not None: + # query = query.join(Post_Tag).join(Tag).filter(Tag.name == tag) + # query = query.order_by(text("rank desc"), Post.created_at.desc()) + + # if page is None or results_per_page is None: + # return extract_results_posts(query) + + # if not page.isdigit() or not results_per_page.isdigit(): + # return abort(400, message="Page and results_per_page fields must " + # "be numbers.") + + # query = limit_query(query, page, results_per_page) return extract_results_posts(query) diff --git a/drp/api/users.py b/drp/api/users.py index e1e3532..b1a2898 100644 --- a/drp/api/users.py +++ b/drp/api/users.py @@ -11,7 +11,7 @@ from ..models import User, UserRole from ..mail import mail -from .utils import require_auth, require_admin, error +from .utils import require_auth, require_admin, abort users = Blueprint("users", __name__) @@ -58,15 +58,13 @@ def create_user(): role = body.get("role") if email is None or password is None: - return error(400, - message="`email` and `password` fields" - "are required.") + abort(400, message="`email` and `password` fields are required.") if len(password) < 8: - return error(400, type="ShortPassword") + abort(400, type="ShortPassword") if role is not None and role != "normal" and role != "admin": - return error(400, message="`role` must be one of {normal, admin}.") + abort(400, message="`role` must be one of {normal, admin}.") role = UserRole.ADMIN if (role == "admin") else UserRole.NORMAL @@ -83,7 +81,7 @@ def create_user(): db.session.commit() except IntegrityError as err: if err.orig.pgcode == "23505": - return error( + abort( 422, message="A user with this email already exists.") else: raise @@ -104,7 +102,7 @@ def create_user(): @require_auth def by_id(id): if g.user.role != UserRole.ADMIN and g.user.id != id: - return error(401, type="InsufficientPermissions") + abort(401, type="InsufficientPermissions") if request.method == "GET": return get_user_by_id(id) @@ -123,7 +121,7 @@ def get_user_by_id(id): user = User.query.filter(User.id == id).one_or_none() if user is None: - return error(404) + abort(404) return serialize_user(user) @@ -135,7 +133,7 @@ def update_user_by_id(id): user = User.query.filter(User.id == id).one_or_none() if user is None: - return error(404) + abort(404) body = request.json @@ -144,7 +142,7 @@ def update_user_by_id(id): if password: if len(password) < 8: - return error(400, type="ShortPassword") + abort(400, type="ShortPassword") hasher = PasswordHasher() hash = hasher.hash(password) @@ -152,10 +150,10 @@ def update_user_by_id(id): if role: if g.user.role != UserRole.ADMIN: - return error(401, type="InsufficientPermission") + abort(401, type="InsufficientPermission") if role != "normal" and role != "admin": - return error(400, message="`role` must be one of {normal, admin}.") + abort(400, message="`role` must be one of {normal, admin}.") user.role = UserRole.ADMIN if (role == "admin") else UserRole.NORMAL @@ -171,7 +169,7 @@ def delete_user_by_id(id): user = User.query.filter(User.id == id).one_or_none() if user is None: - return error(404) + abort(404) db.session.delete(user) db.session.commit() diff --git a/drp/api/utils.py b/drp/api/utils.py index 91bd0b2..b52497d 100644 --- a/drp/api/utils.py +++ b/drp/api/utils.py @@ -1,20 +1,23 @@ from functools import wraps -from flask import request, g +from flask import request, g, jsonify, abort as flask_abort from ..models import User, UserRole from ..utils import decode_authorization_header -def error(code: int, message: str = None, type: str = None): - body = {} +def abort(code: int, message: str = None, type: str = None): + if not type: + type = "error" - if type: - body["type"] = type + body = {"type": type} if message: body["message"] = message - return body, code, {"Content-Type": "application/json"} + response = jsonify(body) + response.status_code = code + + flask_abort(response) def require_auth(f): @@ -22,13 +25,13 @@ def require_auth(f): def decorated_f(*args, **kwargs): claims = decode_authorization_header(request) if claims is None: - return error(401, type="InvalidAuthToken") + abort(401, type="InvalidAuthToken") g.user = User.query.filter( User.email == claims.get("sub")).one_or_none() if g.user is None: - return error(401, type="InvalidAuthToken") + abort(401, type="InvalidAuthToken") return f(*args, **kwargs) @@ -40,16 +43,16 @@ def require_admin(f): def decorated_f(*args, **kwargs): claims = decode_authorization_header(request) if claims is None: - return error(401, type="InvalidAuthToken") + abort(401, type="InvalidAuthToken") g.user = User.query.filter( User.email == claims.get("sub")).one_or_none() if g.user is None: - return error(401, type="InvalidAuthToken") + abort(401, type="InvalidAuthToken") if g.user.role != UserRole.ADMIN: - return error(401, type="InsufficientPermissions") + abort(401, type="InsufficientPermissions") return f(*args, **kwargs) diff --git a/drp/models/post.py b/drp/models/post.py index 823a488..77b2b19 100644 --- a/drp/models/post.py +++ b/drp/models/post.py @@ -21,15 +21,18 @@ class Post(db.Model): id = db.Column(db.Integer, primary_key=True) - is_guideline = db.Column(db.Boolean()) + is_guideline = db.Column(db.Boolean(), nullable=False) latest_rev_id = db.Column(db.Integer, db.ForeignKey( - "post_revisions.id", name="posts_latest_rev_id_fkey")) + "post_revisions.id", name="posts_latest_rev_id_fkey"), nullable=True) latest_rev = relationship("PostRevision", - back_populates="post", foreign_keys=[latest_rev_id]) + revisions = relationship("PostRevision", + foreign_keys="[PostRevision.post_id]", + back_populates="post") + resolves = relationship("Question", back_populates="resolved_by") @@ -38,7 +41,8 @@ class PostRevision(db.Model): id = db.Column(db.Integer, primary_key=True) post_id = db.Column(db.Integer, db.ForeignKey( - "posts.id", name="post_revisions_post_id_fkey", ondelete="CASCADE")) + "posts.id", name="post_revisions_post_id_fkey", ondelete="CASCADE"), + nullable=False) title = db.Column(db.Text(), nullable=False) summary = db.Column(db.Text(), nullable=False) @@ -47,7 +51,9 @@ class PostRevision(db.Model): created_at = db.Column(db.DateTime(timezone=True), nullable=False, server_default=func.now()) - post = relationship("Post", foreign_keys=[post_id]) + post = relationship("Post", + foreign_keys=[post_id], + back_populates="revisions") tags = relationship("Tag", secondary="post_rev_tag") files = relationship("File", back_populates="post_revision") @@ -105,7 +111,8 @@ class File(db.Model): filename = db.Column(db.String(300)) post_rev_id = db.Column(db.Integer, - db.ForeignKey("post_revisions.id")) + db.ForeignKey("post_revisions.id"), + nullable=False) post_revision = relationship('PostRevision', back_populates='files') diff --git a/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py b/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py index 10394ed..a6f5efc 100644 --- a/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py +++ b/migrations/versions/ef0a345b0640_separate_posts_into_posts_and_revisions.py @@ -113,7 +113,7 @@ def upgrade(): # Create the new posts table op.create_table("posts_new", sa.Column("id", sa.Integer(), nullable=False), - sa.Column("is_guideline", sa.Boolean(), nullable=True), + sa.Column("is_guideline", sa.Boolean(), nullable=False), sa.Column("latest_rev_id", sa.Integer(), nullable=True), sa.PrimaryKeyConstraint("id") ) @@ -121,11 +121,11 @@ def upgrade(): # Create the post revisions table op.create_table('post_revisions', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('post_id', sa.Integer(), nullable=True), + sa.Column('post_id', sa.Integer(), nullable=False), sa.Column('title', sa.Text(), nullable=False), sa.Column('summary', sa.Text(), nullable=False), sa.Column('content', sa.Text(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), + sa.Column('created_at', sa.DateTime(timezone=False), server_default=sa.text('now()'), nullable=False), sa.ForeignKeyConstraint( ['post_id'], ['posts_new.id'], ondelete="CASCADE"), @@ -188,6 +188,7 @@ def upgrade(): op.drop_table("post_tag") op.drop_column("files", "post_id") + op.alter_column("files", "post_rev_id", nullable=False) op.drop_column("questions", "post_id") op.alter_column("questions", "post_id_new", new_column_name="post_id") From 892bd927acc9eedeb3d20bd4a8d78494aaf3e73e Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Fri, 26 Jun 2020 20:06:32 +0100 Subject: [PATCH 04/11] Don't unresolve question when deleting linked post --- drp/api/posts.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/drp/api/posts.py b/drp/api/posts.py index a49295f..ff2d684 100644 --- a/drp/api/posts.py +++ b/drp/api/posts.py @@ -179,7 +179,6 @@ def get_post(id): def delete_post(id): for question in Question.query.filter(Question.post_id == id): question.resolved_by = None - question.resolved = False rows_deleted = Post.query.filter(Post.id == id).delete() @@ -279,7 +278,6 @@ def delete_post_revision(post, revision): # Unlink any questions that were resolved by the post for question in Question.query.filter(Question.post_id == id): question.resolved_by = None - question.resolved = False Post.query.filter(Post.id == post.id).delete() From db38ccfd71a4484a9d6422feef66c0d1cd228247 Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Fri, 26 Jun 2020 20:23:36 +0100 Subject: [PATCH 05/11] Return error for invalid tag in get_posts --- drp/api/posts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/drp/api/posts.py b/drp/api/posts.py index ff2d684..f9014ed 100644 --- a/drp/api/posts.py +++ b/drp/api/posts.py @@ -100,7 +100,7 @@ def get_posts(): if tag: tag = Tag.query.filter(Tag.name == tag).one_or_none() if not tag: - return jsonify([]) + abort(400, f"the tag `{tag}` does not exist") query = query.join(PostRev_Tag).join( Tag).filter(Tag.id == tag.id) From 476e5819c4fe474f958e5478bf903d2339071eea Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Fri, 26 Jun 2020 20:52:50 +0100 Subject: [PATCH 06/11] Remove unused files api --- drp/__init__.py | 3 - drp/api/__init__.py | 4 +- drp/api/files.py | 166 +------------------------------ tests/conftest.py | 10 +- tests/test_files.py | 236 -------------------------------------------- 5 files changed, 5 insertions(+), 414 deletions(-) delete mode 100644 tests/test_files.py diff --git a/drp/__init__.py b/drp/__init__.py index 28971b2..5a879a9 100644 --- a/drp/__init__.py +++ b/drp/__init__.py @@ -24,9 +24,6 @@ def init_api(app): api.add_resource(res.TagResource, "/api/tags/") api.add_resource(res.TagListResource, "/api/tags") - api.add_resource(res.FileResource, '/api/files/') - api.add_resource(res.FileListResource, "/api/files") - api.add_resource(res.RawFileViewResource, '/api/rawfiles/view/') api.add_resource(res.RawFileDownloadResource, '/api/rawfiles/download/') diff --git a/drp/api/__init__.py b/drp/api/__init__.py index 2688478..aef6e19 100644 --- a/drp/api/__init__.py +++ b/drp/api/__init__.py @@ -2,8 +2,7 @@ from .posts import posts from .search import PostSearchResource from .tags import TagListResource, TagResource -from .files import (FileResource, FileListResource, RawFileViewResource, - RawFileDownloadResource) +from .files import RawFileViewResource, RawFileDownloadResource from .questions import QuestionResource, QuestionListResource, questions from .site import SiteResource, SiteListResource from .subject import SubjectResource, SubjectListResource @@ -13,7 +12,6 @@ __all__ = ["posts", "PostSearchResource", "TagResource", "TagListResource", "QuestionResource", "QuestionListResource", - "FileResource", "FileListResource", "RawFileViewResource", "RawFileDownloadResource", "SiteResource", "SiteListResource", "SubjectResource", "SubjectListResource", diff --git a/drp/api/files.py b/drp/api/files.py index 1e86909..49f5dd8 100644 --- a/drp/api/files.py +++ b/drp/api/files.py @@ -1,12 +1,7 @@ -import os -import werkzeug -from datetime import datetime -from flask import current_app, request, send_from_directory +from flask import current_app, send_from_directory from flask_restful import Resource, abort -from ..db import db -from ..models import File, Post -from ..swag import swag +from ..models import File def allowed_file(filename, allowed): @@ -14,163 +9,6 @@ def allowed_file(filename, allowed): and filename.rsplit('.', 1)[1].lower() in allowed -@swag.definition("File") -def serialize_file(file): - """ - Represents an uploaded file. - --- - properties: - id: - type: integer - name: - type: string - post: - type: integer - """ - return { - "id": file.id, - "name": file.name, - "post_revision": file.post_rev_id - } - - -class FileResource(Resource): - - def get(self, id): - """ - Gets a single file by id. - --- - parameters: - - name: id - in: path - type: integer - required: true - responses: - 200: - schema: - $ref: "#/definitions/File" - 404: - description: Not found - """ - file = File.query.filter(File.id == id).one_or_none() - return serialize_file(file) if file is not None else abort(404) - - def delete(self, id): - """ - Deletes a single file by id. - --- - parameters: - - name: id - in: path - type: integer - required: true - responses: - 204: - description: Success - 404: - description: Not found - """ - file = File.query.filter(File.id == id).one_or_none() - - if file is None: - return abort(404) - - try: - os.remove(os.path.join( - current_app.config['UPLOAD_FOLDER'], file.filename)) - except OSError as e: - print("Could not delete file, " + repr(e)) - - db.session.delete(file) - db.session.commit() - - return '', 204 - - -class FileListResource(Resource): - - def get(self): - """ - Gets a list of all files. - --- - responses: - 200: - schema: - type: array - items: - $ref: "#/definitions/File" - """ - return [serialize_file(file) for file in File.query.all()] - - def post(self): - """ - Uploads a new file. - --- - parameters: - - in: formData - name: file - type: file - required: true - description: The file to upload. - - in: formData - name: name - type: string - required: true - description: The logical name of the file - - in: formData - name: post - type: string - required: true - description: The associated post - - responses: - 200: - schema: - $ref: "#/definitions/File" - """ - file_content = request.files.get('file') - name = request.form.get('name') - post_id = request.form.get('post') - - if name is None: - return abort(400, message="`name` field is required.") - - if file_content is None: - return abort(400, message="`file` filed is required.") - - if file_content == "": - return abort(400, message="A valid file is required.") - - if not allowed_file(name, - current_app.config['ALLOWED_FILE_EXTENSIONS']): - return abort(400, message=f"The file extension of {name} is " - "not allowed for security reasons. If " - "you believe that this file type is safe " - "to upload, contact the developer.") - - post = Post.query.filter(Post.id == post_id).one_or_none() - if post is None: - return abort(400, message="Invalid post ID, associated post must " - "already exist.") - - # Prefix file name with current time to allow mutliple files with the - # same name - filename = datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f_') + \ - werkzeug.utils.secure_filename(name) - path = os.path.join(current_app.config['UPLOAD_FOLDER'], filename) - if (os.path.isfile(path)): - return abort(422, message="a file upload collision occured, " - "please try again later") - file_content.save(path) - - file = File(name=name, filename=filename, post=post) - - db.session.add(file) - db.session.commit() - - return serialize_file(file) - - class RawFileViewResource(Resource): def get(self, id): diff --git a/tests/conftest.py b/tests/conftest.py index 71355c1..2cc1b7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,8 +17,9 @@ def app(): @pytest.fixture -def db(app, db_downgrade): +def db(app): with app.app_context(): + _db.drop_all() _db.create_all() yield _db @@ -27,13 +28,6 @@ def db(app, db_downgrade): _db.drop_all() -@pytest.fixture(scope="session") -def db_downgrade(app): - from flask_migrate import downgrade - with app.app_context(): - downgrade(revision="base") - - @pytest.fixture(autouse=True) def handle_upload(app): app.config["UPLOAD_FOLDER"] = os.path.join( diff --git a/tests/test_files.py b/tests/test_files.py deleted file mode 100644 index 2a0f6bd..0000000 --- a/tests/test_files.py +++ /dev/null @@ -1,236 +0,0 @@ -import json -import os -from io import BytesIO -from hashlib import sha256 - -from drp.models import File, Post - - -def create_test_post(app, db): - with app.app_context(): - post = Post(title="A title", summary="A summary", content="A content") - db.session.add(post) - db.session.commit() - post_id = post.id - return (post, post_id) - - -def create_files(app, db, files): - with app.app_context(): - for file in files: - db.session.add(file) - db.session.commit() - - -def test_get_all_files(app, db): - post, post_id = create_test_post(app, db) - name = "test.pdf" - filename = "file.pdf" - - count = 3 - create_files(app, db, [File(name=name, filename=filename, - post=post) - for i in range(0, count)]) - - with app.test_client() as client: - response = client.get('/api/files') - - assert "200" in response.status - - data = json.loads(response.data.decode("utf-8")) - - assert len(data) == count - - for file in data: - assert "id" in file - assert name == file["name"] - assert post_id == file["post"] - - -def test_create_file(app, db): - - tests_path = os.path.join(os.path.dirname(app.root_path), "tests") - input_path = os.path.join(tests_path, "input") - - with open(os.path.join(input_path, "Frankenstein.pdf"), 'rb') as file: - read = file.read() - input_hash = sha256(read).hexdigest() - file_bytes = BytesIO(read) - - _, post_id = create_test_post(app, db) - - file = { - "file": (file_bytes, "file1.pdf"), - "name": "Frankenstein.pdf", - "post": post_id - } - - with app.test_client() as client: - response = client.post('/api/files', - content_type='multipart/form-data', - data=file) - - assert "200" in response.status - - data = json.loads(response.data.decode("utf-8")) - - assert "id" in data - assert file["name"] == data["name"] - assert file["post"] == post_id - - post = Post.query.filter( - Post.id == post_id).one_or_none() - - assert data["id"] == post.files[0].id - assert file["name"] == post.files[0].name - - filename = File.query.filter( - File.id == data["id"]).one_or_none().filename - - file_path = os.path.join(tests_path, "output", filename) - - with open(file_path, 'rb') as file: - output_hash = sha256(file.read()).hexdigest() - assert input_hash == output_hash - - -def test_create_bad_file_type(app, db): - - _, post_id = create_test_post(app, db) - - file = { - "file": (BytesIO(b""), 'test.html'), - "name": "bad.html", - "post": post_id - } - - with app.test_client() as client: - response = client.post('/api/files', - content_type='multipart/form-data', - data=file) - - assert "400" in response.status - - data = json.loads(response.data.decode("utf-8")) - - assert "security" in data["message"] - assert "not allowed" in data["message"] - - -def test_create_file_invalid_post(app, db): - - file = { - "file": (BytesIO(b"A test"), 'test.pdf'), - "name": "test.pdf", - "post": 42 - } - - with app.test_client() as client: - response = client.post('/api/files', - content_type='multipart/form-data', - data=file) - - assert "400" in response.status - - data = json.loads(response.data.decode("utf-8")) - - assert "post ID" in data["message"] - - -def test_create_file_no_name(app, db): - - file = { - "file": (BytesIO(b"A test"), 'test.pdf'), - "post": 42 - } - - with app.test_client() as client: - response = client.post('/api/files', - content_type='multipart/form-data', - data=file) - - assert "400" in response.status - - data = json.loads(response.data.decode("utf-8")) - - assert "required" in data["message"] - - -def test_create_file_no_file(app, db): - - file = { - "name": "test.pdf", - "post": 42 - } - - with app.test_client() as client: - response = client.post('/api/files', - content_type='multipart/form-data', - data=file) - - assert "400" in response.status - - data = json.loads(response.data.decode("utf-8")) - - assert "required" in data["message"] - - -def test_create_file_no_post(app, db): - - file = { - "file": (BytesIO(b"A test"), 'test.pdf'), - "name": "test.pdf" - } - - with app.test_client() as client: - response = client.post('/api/files', - content_type='multipart/form-data', - data=file) - - assert "400" in response.status - - data = json.loads(response.data.decode("utf-8")) - - assert "post ID" in data["message"] - - -def test_delete_file(app, db): - - _, post_id = create_test_post(app, db) - - file = { - "file": (BytesIO(b"A test"), 'test.pdf'), - "name": "My test.pdf", - "post": post_id - } - - with app.test_client() as client: - response = client.post('/api/files', - content_type='multipart/form-data', - data=file) - - assert "200" in response.status - - data = json.loads(response.data.decode("utf-8")) - - filename = File.query.filter( - File.id == data["id"]).one_or_none().filename - file_path = os.path.join(os.path.dirname( - app.root_path), "tests", "output", filename) - - assert os.path.isfile(file_path) - - response = client.delete(f"/api/files/{data['id']}") - - assert "204" in response.status - - assert not os.path.isfile(file_path) - - -def test_delete_file_that_doesnt_exist(app, db): - - with app.test_client() as client: - - response = client.delete("/api/files/42") - - assert "404" in response.status From 8b30415972996f15753a6948ec94ae590bcf476e Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Sat, 27 Jun 2020 13:43:40 +0100 Subject: [PATCH 07/11] Fix posts api tests --- drp/api/posts.py | 39 +++++++-- tests/test_posts.py | 204 ++++++++++++++++++++++---------------------- 2 files changed, 134 insertions(+), 109 deletions(-) diff --git a/drp/api/posts.py b/drp/api/posts.py index f9014ed..2df6c05 100644 --- a/drp/api/posts.py +++ b/drp/api/posts.py @@ -128,7 +128,7 @@ def create_post(): tag_names = request.form.getlist("tags") type = request.form.get("type") resolves = request.form.getlist("resolves") - files = request.form.getlist("files") + files = request.files.getlist("files") names = request.form.getlist("file_names") validate_rev_data(title, summary, content, files, names) @@ -177,14 +177,27 @@ def get_post(id): def delete_post(id): - for question in Question.query.filter(Question.post_id == id): - question.resolved_by = None + post_query = Post.query.filter(Post.id == id) + post = post_query.one_or_none() - rows_deleted = Post.query.filter(Post.id == id).delete() - - if rows_deleted == 0: + if post is None: abort(404) + # Delete all files attached to any of the post's revisions + for revision in post.revisions: + for file in revision.files: + try: + os.remove(os.path.join( + current_app.config['UPLOAD_FOLDER'], file.filename)) + except OSError as e: + print("Could not delete file, " + repr(e)) + + db.session.delete(file) + + for question in Question.query.filter(Question.post_id == id): + question.resolved_by = None + + post_query.delete() db.session.commit() return jsonify({"message": "deleted"}), 204 @@ -226,7 +239,7 @@ def create_post_revision(post): content = request.form.get("content") tag_names = request.form.getlist("tags") resolves = request.form.getlist("resolves") - files = request.form.getlist("files") + files = request.files.getlist("files") names = request.form.getlist("file_names") validate_rev_data(title, summary, content, files, names) @@ -273,6 +286,16 @@ def get_post_revision(post, revision): def delete_post_revision(post, revision): + # Delete all files attached to the revision + for file in revision.files: + try: + os.remove(os.path.join( + current_app.config['UPLOAD_FOLDER'], file.filename)) + except OSError as e: + print("Could not delete file, " + repr(e)) + + db.session.delete(file) + # If this is the only revision, delete the whole post if len(post.revisions) == 1: # Unlink any questions that were resolved by the post @@ -298,6 +321,8 @@ def delete_post_revision(post, revision): db.session.commit() + return jsonify({"message": "deleted"}), 204 + def validate_rev_data(title, summary, content, files, names): # Check that required fields are present diff --git a/tests/test_posts.py b/tests/test_posts.py index e2ffe4a..1ddbafb 100644 --- a/tests/test_posts.py +++ b/tests/test_posts.py @@ -3,14 +3,29 @@ from io import BytesIO from hashlib import sha256 -from drp.models import Post, Tag, File +from drp.models import Post, PostRevision, Tag, File def create_posts(app, db, posts): with app.app_context(): - for post in posts: - db.session.add(post) - db.session.commit() + for p in posts: + create_post(app, db, p) + + +def create_post(app, db, data, is_guideline=False): + post = Post(is_guideline=is_guideline) + db.session.add(post) + db.session.commit() + + rev = PostRevision(title=data["title"], summary=data["summary"], + content=data["content"], post=post) + db.session.add(rev) + db.session.commit() + + post.latest_rev_id = rev.id + db.session.commit() + + return post.id def test_get_all_posts(app, db): @@ -19,22 +34,26 @@ def test_get_all_posts(app, db): content = "A few paragraphs of content..." count = 3 - create_posts(app, db, [Post(title=title, summary=summary, content=content) + create_posts(app, db, [{"title": title, + "summary": summary, + "content": content} for i in range(0, count)]) with app.test_client() as client: - response = client.get("/api/posts") + response = client.get("/api/posts/") assert "200" in response.status data = json.loads(response.data.decode("utf-8")) + print(data) + assert len(data) == count for post in data: - assert post["title"] == title - assert post["summary"] == summary - assert post["content"] == content + assert post["latest_revision"]["title"] == title + assert post["latest_revision"]["summary"] == summary + assert post["latest_revision"]["content"] == content def test_get_all_guidelines(app, db): @@ -44,13 +63,14 @@ def test_get_all_guidelines(app, db): content = "A few paragraphs of content..." with app.app_context(): - db.session.add(Post(title=title_ng, summary=summary, content=content)) - db.session.add(Post(title=title_g, summary=summary, - content=content, is_guideline=True)) - db.session.commit() + create_post(app, db, {"title": title_ng, + "summary": summary, "content": content}) + create_post(app, db, + {"title": title_g, "summary": summary, "content": content}, + is_guideline=True) with app.test_client() as client: - response = client.get("/api/posts?guidelines_only=true") + response = client.get("/api/posts/?type=guideline") assert "200" in response.status @@ -58,7 +78,7 @@ def test_get_all_guidelines(app, db): assert len(data) == 1 - assert data[0]["title"] == title_g + assert data[0]["latest_revision"]["title"] == title_g def test_create_post(app, db): @@ -69,7 +89,7 @@ def test_create_post(app, db): "content": "A few paragraphs of content..." } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -77,12 +97,12 @@ def test_create_post(app, db): data = json.loads(response.data.decode("utf-8")) - assert post["title"] == data["title"] - assert post["summary"] == data["summary"] - assert post["content"] == data["content"] + assert post["title"] == data["latest_revision"]["title"] + assert post["summary"] == data["latest_revision"]["summary"] + assert post["content"] == data["latest_revision"]["content"] - assert "id" in data - assert "created_at" in data + assert "id" in data["latest_revision"] + assert "created_at" in data["latest_revision"] def test_update_post(app, db): @@ -94,7 +114,7 @@ def test_update_post(app, db): "is_guideline": "true" } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -108,23 +128,26 @@ def test_update_post(app, db): "title": "A new title", "summary": "", "content": "", - "is_guideline": "true", - "updates": str(id) } - response = client.post('/api/posts', + response = client.post(f'/api/posts/{id}/revisions', content_type='multipart/form-data', data=update) - data = json.loads(response.data.decode("utf-8")) - assert "200" in response.status data = json.loads(response.data.decode("utf-8")) - assert id == data["id"] assert update["title"] == data["title"] + with app.app_context(): + revs = PostRevision.query.all() + + assert len(revs) == 2 + + for rev in revs: + assert rev.post_id == id + def test_create_post_with_missing_content(app, db): with app.test_client() as client: @@ -133,7 +156,7 @@ def test_create_post_with_missing_content(app, db): "summary": "A summary" } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -147,7 +170,7 @@ def test_create_post_with_missing_summary(app, db): "content": "A few paragraphs of content..." } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -175,7 +198,7 @@ def test_create_post_with_tags(app, db): "tags": ["Tag 1", "Tag 2"] } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -183,17 +206,17 @@ def test_create_post_with_tags(app, db): data = json.loads(response.data.decode("utf-8")) - assert post["title"] == data["title"] - assert post["content"] == data["content"] - assert post["summary"] == data["summary"] + assert post["title"] == data["latest_revision"]["title"] + assert post["content"] == data["latest_revision"]["content"] + assert post["summary"] == data["latest_revision"]["summary"] - assert "id" in data - assert "created_at" in data + assert "id" in data["latest_revision"] + assert "created_at" in data["latest_revision"] print(f"______TAGS_____: {post['tags']}") - assert {"id": id1, "name": "Tag 1"} in data["tags"] - assert {"id": id2, "name": "Tag 2"} in data["tags"] + assert {"id": id1, "name": "Tag 1"} in data["latest_revision"]["tags"] + assert {"id": id2, "name": "Tag 2"} in data["latest_revision"]["tags"] def test_create_post_with_files(app, db): @@ -216,10 +239,10 @@ def test_create_post_with_files(app, db): "summary": "A summary", "content": "A content", "files": [(file1, "file1.pdf"), (file2, "file2.png")], - "names": ["name1.pdf", "name2.jpg"] + "file_names": ["name1.pdf", "name2.jpg"] } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -227,23 +250,21 @@ def test_create_post_with_files(app, db): data = json.loads(response.data.decode("utf-8")) - assert post["title"] == data["title"] - assert post["content"] == data["content"] - assert post["summary"] == data["summary"] + assert post["title"] == data["latest_revision"]["title"] + assert post["content"] == data["latest_revision"]["content"] + assert post["summary"] == data["latest_revision"]["summary"] - assert "id" in data - assert "created_at" in data + assert "id" in data["latest_revision"] + assert "created_at" in data["latest_revision"] - assert len(data["files"]) == 2 + assert len(data["latest_revision"]["files"]) == 2 - files = data["files"] + files = data["latest_revision"]["files"] - assert post["names"][0] == files[0]["name"] - assert data["id"] == files[0]["post"] + assert post["file_names"][0] == files[0]["name"] assert "id" in files[0] - assert post["names"][1] == files[1]["name"] - assert data["id"] == files[1]["post"] + assert post["file_names"][1] == files[1]["name"] assert "id" in files[1] filename1 = File.query.filter( @@ -270,7 +291,7 @@ def test_create_post_with_missing_title(app, db): "content": "A few paragraphs of content..." } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -285,10 +306,10 @@ def test_create_post_with_bad_file_type(app, db): "summary": "A summary", "content": "A content", "files": [(BytesIO(b""), 'test.html')], - "names": ["name1.html"] + "file_names": ["name1.html"] } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -308,10 +329,10 @@ def test_create_post_with_files_names_mismatch(app, db): "summary": "A summary", "content": "A content", "files": [(BytesIO(b"A test file"), 'test.pdf')], - "names": ["name1.pdf", "name2.png"] + "file_names": ["name1.pdf", "name2.png"] } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -328,25 +349,19 @@ def test_get_single_post(app, db): content = "A few paragraphs of content..." with app.app_context(): - post = Post(title=title, summary=summary, content=content) - db.session.add(post) - db.session.commit() - id = post.id + id = create_post(app, db, {"title": title, + "summary": summary, "content": content}) with app.test_client() as client: response = client.get(f"/api/posts/{id}") assert "200" in response.status - posts = json.loads(response.data.decode("utf-8")) - - assert len(posts) == 1 + post = json.loads(response.data.decode("utf-8")) - post = posts[0] - - assert post["title"] == title - assert post["summary"] == summary - assert post["content"] == content + assert post["latest_revision"]["title"] == title + assert post["latest_revision"]["summary"] == summary + assert post["latest_revision"]["content"] == content def test_get_single_post_that_doesnt_exist(app, db): @@ -361,10 +376,8 @@ def test_delete_post(app, db): content = "A few paragraphs of content..." with app.app_context(): - post = Post(title=title, summary=summary, content=content) - db.session.add(post) - db.session.commit() - id = post.id + id = create_post(app, db, {"title": title, + "summary": summary, "content": content}) with app.test_client() as client: response = client.delete(f"/api/posts/{id}") @@ -381,34 +394,25 @@ def test_delete_last_revision(app, db): content = "A few paragraphs of content..." with app.app_context(): - old = Post(title=title_old, summary=summary, - content=content, is_guideline=True) - db.session.add(old) - db.session.commit() - id = old.post_id - new = Post(title=title_new, summary=summary, - content=content, is_guideline=True, post_id=id) - old.is_current = False + id = create_post(app, db, {"title": title_old, + "summary": summary, "content": content}) + new = PostRevision(title=title_new, summary=summary, + content=content, post_id=id) db.session.add(new) db.session.commit() - revision_id = new.id + rev_id = new.id with app.test_client() as client: - response = client.delete(f"/api/revisions/{revision_id}") + response = client.delete(f"/api/posts/{id}/revisions/{rev_id}") assert "204" in response.status response = client.get(f"/api/posts/{id}") assert "200" in response.status - data = json.loads(response.data.decode("utf-8")) - - assert len(data) == 1 + post = json.loads(response.data.decode("utf-8")) - post = data[0] - - assert post["is_current"] - assert post["title"] == title_old + assert post["latest_revision"]["title"] == title_old assert post["id"] == id @@ -426,10 +430,10 @@ def test_delete_post_with_file(app, db): "summary": "A summary", "content": "A content", "files": [(file_bytes, "file1.pdf")], - "names": ["name1.pdf"] + "file_names": ["name1.pdf"] } - response = client.post('/api/posts', + response = client.post('/api/posts/', content_type='multipart/form-data', data=post) @@ -438,7 +442,7 @@ def test_delete_post_with_file(app, db): data = json.loads(response.data.decode("utf-8")) filename = File.query.filter( - File.id == data["files"][0]["id"]).one_or_none().filename + File.id == data["latest_revision"]["files"][0]["id"]).one_or_none().filename file_path = os.path.join(tests_path, "output", filename) assert os.path.isfile(file_path) @@ -456,10 +460,8 @@ def test_delete_single_post_that_doesnt_exist(app, db): content = "A few paragraphs of content..." with app.app_context(): - post = Post(title=title, summary=summary, content=content) - db.session.add(post) - db.session.commit() - id = post.id + id = create_post(app, db, {"title": title, + "summary": summary, "content": content}) with app.test_client() as client: response = client.delete(f"/api/posts/{id + 1}") @@ -479,22 +481,20 @@ def test_timezone_utc(app, db): content = "A few paragraphs of content..." with app.app_context(): - post = Post(title=title, summary=summary, content=content) - db.session.add(post) - db.session.commit() - id = post.id + id = create_post(app, db, {"title": title, + "summary": summary, "content": content}) with app.test_client() as client: response = client.get(f"/api/posts/{id}") assert "200" in response.status - post = json.loads(response.data.decode("utf-8"))[0] + post = json.loads(response.data.decode("utf-8")) from datetime import datetime, timedelta utc_now = datetime.utcnow() created_at = datetime.fromisoformat( - post["created_at"]).replace(tzinfo=None) + post["latest_revision"]["created_at"]).replace(tzinfo=None) assert utc_now - created_at < timedelta(milliseconds=5000) From 077adc9167a86a80830373147a71b2b28e2acb43 Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Sat, 27 Jun 2020 13:45:49 +0100 Subject: [PATCH 08/11] Fix long line lint --- tests/test_posts.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_posts.py b/tests/test_posts.py index 1ddbafb..12c952c 100644 --- a/tests/test_posts.py +++ b/tests/test_posts.py @@ -441,8 +441,10 @@ def test_delete_post_with_file(app, db): data = json.loads(response.data.decode("utf-8")) - filename = File.query.filter( - File.id == data["latest_revision"]["files"][0]["id"]).one_or_none().filename + filename = File.query \ + .filter(File.id == data["latest_revision"]["files"][0]["id"]) \ + .one_or_none() \ + .filename file_path = os.path.join(tests_path, "output", filename) assert os.path.isfile(file_path) From 55058e755998fd02efb5f36fd619093aeb017757 Mon Sep 17 00:00:00 2001 From: Adam Dejl Date: Sat, 27 Jun 2020 16:39:40 +0200 Subject: [PATCH 09/11] Fix search --- drp/api/search.py | 87 +++++++++++++++++--------------- tests/test_search.py | 117 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 145 insertions(+), 59 deletions(-) diff --git a/drp/api/search.py b/drp/api/search.py index e1a91f4..6451501 100644 --- a/drp/api/search.py +++ b/drp/api/search.py @@ -8,7 +8,7 @@ from .posts import serialize_post -# from ..models import Post, Post_Tag, Tag +from ..models import Post, PostRevision, PostRev_Tag, Tag from ..db import db @@ -33,7 +33,8 @@ def construct_fulltext_query_and_rank(searched): # Final text search query ts_query = func.to_tsquery('english', prefix_ts_query_text) # Rank for each search result - ts_rank = func.ts_rank_cd(Post.__ts_vector__, ts_query).label("rank") + ts_rank = func.ts_rank_cd( + PostRevision.__ts_vector__, ts_query).label("rank") return (ts_query, ts_rank) @@ -68,14 +69,13 @@ def get(self, searched): in: query type: number required: false - - name: guidelines_only + - name: type in: query - type: boolean - required: false - - name: include_old - in: query - type: boolean - required: false + type: string + enum: + - any + - update + - guideline - name: tag in: query type: string @@ -87,35 +87,44 @@ def get(self, searched): items: $ref: "#/definitions/Post" """ - # if searched == "": - # return abort(400, message="Empty string search is invalid.") - - # page = request.args.get("page") - # results_per_page = request.args.get("results_per_page") - # guidelines_only = request.args.get("guidelines_only") - # include_old = request.args.get("include_old") - # tag = request.args.get("tag") - - # ts_query, ts_rank = construct_fulltext_query_and_rank(searched) - - # # Query for the search results ordered by rank - # query = db.session.query(Post, ts_rank) \ - # .filter(Post.__ts_vector__.op('@@')(ts_query)) - # if (include_old != "true"): - # query = query.filter(Post.is_current) - # if guidelines_only == "true": - # query = query.filter(Post.is_guideline) - # if tag is not None: - # query = query.join(Post_Tag).join(Tag).filter(Tag.name == tag) - # query = query.order_by(text("rank desc"), Post.created_at.desc()) - - # if page is None or results_per_page is None: - # return extract_results_posts(query) - - # if not page.isdigit() or not results_per_page.isdigit(): - # return abort(400, message="Page and results_per_page fields must " - # "be numbers.") - - # query = limit_query(query, page, results_per_page) + if searched == "": + return abort(400, message="Empty string search is invalid.") + + page = request.args.get("page") + results_per_page = request.args.get("results_per_page") + type = request.args.get("type") + tag = request.args.get("tag") + + ts_query, ts_rank = construct_fulltext_query_and_rank(searched) + + # Query for the search results ordered by rank + query = db.session.query(Post, ts_rank) \ + .join(Post.latest_rev) \ + .filter(PostRevision.__ts_vector__.op('@@')(ts_query)) + + if type == "update": + query = query.filter(Post.is_guideline == False) # noqa + elif type == "guideline": + query = query.filter(Post.is_guideline) + + if tag is not None: + tag = Tag.query.filter(Tag.name == tag).one_or_none() + + if not tag: + return abort(400, f"the tag `{tag}` does not exist") + + query = query.join(PostRev_Tag).join(Tag).filter(Tag.id == tag.id) + + query = query.order_by( + text("rank desc"), PostRevision.created_at.desc()) + + if page is None or results_per_page is None: + return extract_results_posts(query) + + if not page.isdigit() or not results_per_page.isdigit(): + return abort(400, message="Page and results_per_page fields " + "must be numbers.") + + query = limit_query(query, page, results_per_page) return extract_results_posts(query) diff --git a/tests/test_search.py b/tests/test_search.py index 287fe09..0e1183f 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,6 +1,6 @@ import json -from drp.models import Post +from drp.models import Post, PostRevision, Tag def add_test_posts(app, db): @@ -53,16 +53,46 @@ def add_test_posts(app, db): { "title": "Test 3", "summary": "Test summary", - "content": "Alpha beta" + "content": "Alpha beta", } ] with app.app_context(): - for post in posts: - post = Post(title=post["title"], - summary=post["summary"], content=post["content"]) - db.session.add(post) + t1 = Tag(name="Tag 1") + t2 = Tag(name="Tag 2") + + db.session.add(t1) + db.session.add(t2) db.session.commit() + for post_data in posts: + if post_data["title"] == "Test 1" \ + or post_data["title"] == "Test 3": + post = Post(is_guideline=True) + else: + post = Post(is_guideline=False) + db.session.add(post) + db.session.commit() + + if post_data["title"] == "Test 1": + post_revision = PostRevision( + title=post_data["title"], summary=post_data["summary"], + content=post_data["content"], post=post, tags=[t1, t2]) + elif post_data["title"] == "Test 3": + post_revision = PostRevision( + title=post_data["title"], summary=post_data["summary"], + content=post_data["content"], post=post, tags=[t1]) + else: + post_revision = PostRevision( + title=post_data["title"], summary=post_data["summary"], + content=post_data["content"], post=post) + + db.session.add(post_revision) + db.session.commit() + + post.latest_rev = post_revision + + db.session.commit() + def test_search_single_content(app, db): add_test_posts(app, db) @@ -75,8 +105,8 @@ def test_search_single_content(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 2 - assert "beginning" in posts[0]["title"] - assert "Turtle" in posts[1]["title"] + assert "beginning" in posts[0]["latest_revision"]["title"] + assert "Turtle" in posts[1]["latest_revision"]["title"] def test_search_two_content(app, db): @@ -90,8 +120,8 @@ def test_search_two_content(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 2 - assert "beginning" in posts[0]["title"] - assert "Turtle" in posts[1]["title"] + assert "beginning" in posts[0]["latest_revision"]["title"] + assert "Turtle" in posts[1]["latest_revision"]["title"] def test_search_three_across_columns(app, db): @@ -105,7 +135,7 @@ def test_search_three_across_columns(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 1 - assert "beginning" in posts[0]["title"] + assert "beginning" in posts[0]["latest_revision"]["title"] def test_search_order_by_rank(app, db): @@ -119,9 +149,9 @@ def test_search_order_by_rank(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 3 - assert posts[0]["title"] == "Test 3" - assert posts[1]["title"] == "Test 1" - assert posts[2]["title"] == "Test 2" + assert posts[0]["latest_revision"]["title"] == "Test 3" + assert posts[1]["latest_revision"]["title"] == "Test 1" + assert posts[2]["latest_revision"]["title"] == "Test 2" def test_search_form(app, db): @@ -135,7 +165,7 @@ def test_search_form(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 1 - assert "beginning" in posts[0]["title"] + assert "beginning" in posts[0]["latest_revision"]["title"] def test_search_prefix(app, db): @@ -149,7 +179,7 @@ def test_search_prefix(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 1 - assert "beginning" in posts[0]["title"] + assert "beginning" in posts[0]["latest_revision"]["title"] response = client.get("/api/search/posts/conta") @@ -158,7 +188,7 @@ def test_search_prefix(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 1 - assert "Turtle" in posts[0]["title"] + assert "Turtle" in posts[0]["latest_revision"]["title"] def test_search_stop_word(app, db): @@ -219,8 +249,8 @@ def test_search_first_page(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 2 - assert posts[0]["title"] == "Test 3" - assert posts[1]["title"] == "Test 1" + assert posts[0]["latest_revision"]["title"] == "Test 3" + assert posts[1]["latest_revision"]["title"] == "Test 1" def test_search_second_page(app, db): @@ -235,7 +265,7 @@ def test_search_second_page(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 1 - assert posts[0]["title"] == "Test 2" + assert posts[0]["latest_revision"]["title"] == "Test 2" def test_search_high_page(app, db): @@ -250,3 +280,50 @@ def test_search_high_page(app, db): posts = json.loads(response.data.decode("utf-8")) assert len(posts) == 0 + + +def test_search_guidelines_only(app, db): + add_test_posts(app, db) + + with app.test_client() as client: + response = client.get("/api/search/posts/alpha beta?type=guideline") + + assert "200" in response.status + + posts = json.loads(response.data.decode("utf-8")) + + assert len(posts) == 2 + + +def test_search_updates_only(app, db): + add_test_posts(app, db) + + with app.test_client() as client: + response = client.get("/api/search/posts/alpha beta?type=update") + + assert "200" in response.status + + posts = json.loads(response.data.decode("utf-8")) + + assert len(posts) == 1 + + +def test_search_tags(app, db): + add_test_posts(app, db) + + with app.test_client() as client: + response = client.get("/api/search/posts/alpha beta?tag=Tag 1") + + assert "200" in response.status + + posts = json.loads(response.data.decode("utf-8")) + + assert len(posts) == 2 + + response = client.get("/api/search/posts/alpha beta?tag=Tag 2") + + assert "200" in response.status + + posts = json.loads(response.data.decode("utf-8")) + + assert len(posts) == 1 From 0ce61743110f74534d06eb29a5084debd671005a Mon Sep 17 00:00:00 2001 From: Adam Dejl Date: Sat, 27 Jun 2020 17:03:13 +0200 Subject: [PATCH 10/11] Remove duplicate join --- drp/api/posts.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/drp/api/posts.py b/drp/api/posts.py index 2df6c05..b421666 100644 --- a/drp/api/posts.py +++ b/drp/api/posts.py @@ -68,8 +68,10 @@ def get_posts(): page = request.args.get("page") per_page = request.args.get("per_page") - query = Post.query.join(Post.latest_rev).options( - joinedload("latest_rev").options( + query = Post.query.join(Post.latest_rev) \ + .options( + joinedload("latest_rev") + .options( joinedload("tags"), joinedload("files"))) @@ -104,7 +106,7 @@ def get_posts(): query = query.join(PostRev_Tag).join( Tag).filter(Tag.id == tag.id) - query = query.join(Post.latest_rev).order_by( + query = query.order_by( PostRevision.created_at.desc()) # Pagination From 45fe2225a5b0ebda5d3207111364a28999217ebf Mon Sep 17 00:00:00 2001 From: Hasan Ali Date: Sat, 27 Jun 2020 18:37:10 +0100 Subject: [PATCH 11/11] Remove trailing slash from posts api root --- drp/api/posts.py | 8 +++----- tests/test_posts.py | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/drp/api/posts.py b/drp/api/posts.py index b421666..156c46d 100644 --- a/drp/api/posts.py +++ b/drp/api/posts.py @@ -52,7 +52,7 @@ def serialize_post(post): } -@posts.route("/", methods=["GET", "POST"]) +@posts.route("", methods=["GET", "POST"]) def all_posts(): if request.method == "GET": return get_posts() @@ -68,10 +68,8 @@ def get_posts(): page = request.args.get("page") per_page = request.args.get("per_page") - query = Post.query.join(Post.latest_rev) \ - .options( - joinedload("latest_rev") - .options( + query = Post.query.join(Post.latest_rev).options( + joinedload("latest_rev").options( joinedload("tags"), joinedload("files"))) diff --git a/tests/test_posts.py b/tests/test_posts.py index 12c952c..ff238d0 100644 --- a/tests/test_posts.py +++ b/tests/test_posts.py @@ -40,7 +40,7 @@ def test_get_all_posts(app, db): for i in range(0, count)]) with app.test_client() as client: - response = client.get("/api/posts/") + response = client.get("/api/posts") assert "200" in response.status @@ -70,7 +70,7 @@ def test_get_all_guidelines(app, db): is_guideline=True) with app.test_client() as client: - response = client.get("/api/posts/?type=guideline") + response = client.get("/api/posts?type=guideline") assert "200" in response.status @@ -89,7 +89,7 @@ def test_create_post(app, db): "content": "A few paragraphs of content..." } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -114,7 +114,7 @@ def test_update_post(app, db): "is_guideline": "true" } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -156,7 +156,7 @@ def test_create_post_with_missing_content(app, db): "summary": "A summary" } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -170,7 +170,7 @@ def test_create_post_with_missing_summary(app, db): "content": "A few paragraphs of content..." } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -198,7 +198,7 @@ def test_create_post_with_tags(app, db): "tags": ["Tag 1", "Tag 2"] } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -242,7 +242,7 @@ def test_create_post_with_files(app, db): "file_names": ["name1.pdf", "name2.jpg"] } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -291,7 +291,7 @@ def test_create_post_with_missing_title(app, db): "content": "A few paragraphs of content..." } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -309,7 +309,7 @@ def test_create_post_with_bad_file_type(app, db): "file_names": ["name1.html"] } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -332,7 +332,7 @@ def test_create_post_with_files_names_mismatch(app, db): "file_names": ["name1.pdf", "name2.png"] } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post) @@ -433,7 +433,7 @@ def test_delete_post_with_file(app, db): "file_names": ["name1.pdf"] } - response = client.post('/api/posts/', + response = client.post('/api/posts', content_type='multipart/form-data', data=post)