test_api_workspaced_base.py 10.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
#-*- coding: utf8 -*-
'''
Faraday Penetration Test IDE
Copyright (C) 2013  Infobyte LLC (http://www.infobytesec.com/)
See the file 'doc/LICENSE' for the license information

'''

"""Generic tests for APIs prefixed with a workspace_name"""

import pytest
from sqlalchemy.orm.util import was_deleted
from server.models import db, Workspace, Credential
from test_api_pagination import PaginationTestsMixin as \
    OriginalPaginationTestsMixin

API_PREFIX = '/v2/ws/'
OBJECT_COUNT = 5


@pytest.mark.usefixtures('logged_user')
class GenericAPITest:
    """Shared base for tests of endpoints living under API_PREFIX + workspace.

    Concrete test classes configure the class attributes below.  The autouse
    fixture fills ``self.workspace``, ``self.objects`` and
    ``self.first_object`` before every test.
    """

    # Configured by concrete subclasses.
    model = None          # model class; ``url`` uses it to detect instances
    factory = None        # factory used to create/build test objects
    api_endpoint = None   # path segment that follows the workspace name
    pk_field = 'id'
    unique_fields = []    # fields the create/update mixins treat as unique
    update_fields = []    # fields compared after a successful update

    @pytest.fixture(autouse=True)
    def load_workspace_with_objects(self, database, session, workspace):
        """Create OBJECT_COUNT objects in ``workspace`` and remember them.

        Returns the workspace, which is also stored as ``self.workspace``.
        """
        batch = self.factory.create_batch(OBJECT_COUNT, workspace=workspace)
        self.objects = batch
        self.first_object = batch[0]
        session.add_all(batch)
        session.commit()
        assert workspace.id is not None
        self.workspace = workspace
        return workspace

    @pytest.fixture
    def object_instance(self, session, workspace):
        """Return a factory-created object of ``workspace`` after a commit."""
        instance = self.factory.create(workspace=workspace)
        session.commit()
        return instance

    def url(self, obj=None, workspace=None):
        """Build the endpoint URL, optionally pointing at a single object.

        :param obj: a ``self.model`` instance (its ``id`` is used) or any raw
            identifier; when omitted the collection URL is returned.
        :param workspace: workspace whose name goes in the path; falls back
            to ``self.workspace`` when falsy.
        :returns: ``API_PREFIX + <ws name> + '/' + <endpoint> + '/'``,
            followed by ``<id>/`` when ``obj`` is given.
        """
        target_ws = workspace or self.workspace
        pieces = [API_PREFIX, target_ws.name, '/', self.api_endpoint, '/']
        if obj is not None:
            raw_id = obj.id if isinstance(obj, self.model) else obj
            pieces.append(unicode(raw_id) + u'/')
        return ''.join(pieces)


class ListTestsMixin:
    """Tests for listing the objects of a workspace-scoped endpoint.

    Relies on ``self.url``, ``self.factory`` and ``self.workspace`` from the
    API test base class it is combined with.  ``view_class`` must be set to
    the view under test so its ``_envelope_list`` can be patched.
    """
    view_class = None  # Must be overridden

    @pytest.fixture
    def mock_envelope_list(self, monkeypatch):
        """Patch ``view_class._envelope_list`` to return ``{"data": objects}``."""
        assert self.view_class is not None, 'You must define view_class ' \
            'in order to use ListTestsMixin or PaginationTestsMixin'

        # The first parameter is the view instance; it is named ``view`` so
        # it does not shadow the test class's ``self`` in this fixture.
        def _envelope_list(view, objects, pagination_metadata=None):
            # Pagination metadata is ignored; tests only see the object list.
            return {"data": objects}
        monkeypatch.setattr(self.view_class, '_envelope_list', _envelope_list)

    @pytest.mark.usefixtures('mock_envelope_list')
    def test_list_retrieves_all_items_from_workspace(self, test_client,
                                                     second_workspace,
                                                     session):
        """An object from another workspace must not appear in the list."""
        obj = self.factory.create(workspace=second_workspace)
        session.add(obj)
        session.commit()
        res = test_client.get(self.url())
        assert res.status_code == 200
        assert len(res.json['data']) == OBJECT_COUNT

    def test_can_list_readonly(self, test_client, session):
        """Listing is still allowed when the workspace is read-only."""
        self.workspace.readonly = True
        session.commit()
        res = test_client.get(self.url())
        assert res.status_code == 200


class RetrieveTestsMixin:
    """GET tests for a single object of a workspace-scoped endpoint."""

    def test_retrieve_one_object(self, test_client):
        """The first object is returned as a JSON object."""
        response = test_client.get(self.url(self.first_object))
        assert response.status_code == 200
        assert isinstance(response.json, dict)

    # NOTE(review): "workspcae" is a typo, kept because renaming the method
    # could break code that overrides or disables this test by name.
    def test_retrieve_fails_object_of_other_workspcae(self,
                                                      test_client,
                                                      session,
                                                      second_workspace):
        """An object is not reachable through another workspace's URL."""
        foreign_url = self.url(self.first_object, second_workspace)
        response = test_client.get(foreign_url)
        assert response.status_code == 404

    @pytest.mark.parametrize('object_id', [12345, -1, 'xxx', u'áá'])
    def test_404_when_retrieving_unexistent_object(self, test_client,
                                                   object_id):
        """Unknown, negative, non-numeric and non-ASCII ids all give 404."""
        response = test_client.get(self.url(object_id))
        assert response.status_code == 404


class CreateTestsMixin:
    """POST tests for workspace-scoped endpoints."""

    def test_create_succeeds(self, test_client):
        """A valid payload creates one object bound to the URL's workspace."""
        data = self.factory.build_dict(workspace=self.workspace)
        res = test_client.post(self.url(),
                               data=data)
        assert res.status_code == 201, (res.status_code, res.data)
        assert self.model.query.count() == OBJECT_COUNT + 1
        object_id = res.json['id']
        obj = self.model.query.get(object_id)
        assert obj.workspace == self.workspace

    def test_create_fails_readonly(self, test_client):
        """Creating in a read-only workspace is rejected with 403."""
        self.workspace.readonly = True
        db.session.commit()
        data = self.factory.build_dict(workspace=self.workspace)
        res = test_client.post(self.url(),
                               data=data)
        db.session.commit()
        assert res.status_code == 403
        assert self.model.query.count() == OBJECT_COUNT

    def test_create_inactive_fails(self, test_client):
        """Creating in a deactivated workspace is rejected with 403."""
        self.workspace.deactivate()
        db.session.commit()
        data = self.factory.build_dict(workspace=self.workspace)
        res = test_client.post(self.url(),
                               data=data)
        assert res.status_code == 403, (res.status_code, res.data)
        assert self.model.query.count() == OBJECT_COUNT

    def test_create_fails_with_empty_dict(self, test_client):
        """An empty payload is rejected with 400."""
        res = test_client.post(self.url(), data={})
        assert res.status_code == 400

    def test_create_fails_with_existing(self, session, test_client):
        """Reusing a unique value from the same workspace yields 409."""
        for unique_field in self.unique_fields:
            data = self.factory.build_dict()
            data[unique_field] = getattr(self.first_object, unique_field)
            res = test_client.post(self.url(), data=data)
            assert res.status_code == 409
            assert self.model.query.count() == OBJECT_COUNT

    def test_create_with_existing_in_other_workspace(self, test_client,
                                                     session,
                                                     second_workspace):
        """Reusing a unique value taken only in another workspace is allowed."""
        if not self.unique_fields:
            # Report the test as skipped instead of a silent, vacuous pass.
            pytest.skip('no unique_fields defined for this model')
        unique_field = self.unique_fields[0]
        other_object = self.factory.create(workspace=second_workspace)
        session.commit()

        data = self.factory.build_dict()
        data[unique_field] = getattr(other_object, unique_field)
        res = test_client.post(self.url(), data=data)
        assert res.status_code == 201
        # Two new objects exist: other_object in the second workspace and
        # the one just POSTed to this workspace.
        assert self.model.query.count() == OBJECT_COUNT + 2


class UpdateTestsMixin:
    """PUT tests for a single object of a workspace-scoped endpoint."""

    def test_update_an_object(self, test_client):
        """A valid PUT succeeds and the response reflects ``update_fields``."""
        data = self.factory.build_dict(workspace=self.workspace)
        res = test_client.put(self.url(self.first_object),
                              data=data)
        assert res.status_code == 200
        assert self.model.query.count() == OBJECT_COUNT
        for updated_field in self.update_fields:
            assert res.json[updated_field] == getattr(self.first_object,
                                                      updated_field)

    def test_update_an_object_readonly_fails(self, test_client):
        """Updating in a read-only workspace gives 403 and changes nothing.

        NOTE(review): the request is only issued once per entry of
        ``unique_fields``, so with an empty list this test checks nothing.
        """
        self.workspace.readonly = True
        db.session.commit()
        for unique_field in self.unique_fields:
            data = self.factory.build_dict()
            old_field = getattr(self.objects[0], unique_field)
            old_id = self.objects[0].id
            res = test_client.put(self.url(self.first_object), data=data)
            db.session.commit()
            assert res.status_code == 403
            assert self.model.query.count() == OBJECT_COUNT
            stored = self.model.query.filter(self.model.id == old_id).one()
            assert old_field == getattr(stored, unique_field)

    def test_update_inactive_fails(self, test_client):
        """Updating in a deactivated workspace is rejected with 403."""
        self.workspace.deactivate()
        db.session.commit()
        data = self.factory.build_dict(workspace=self.workspace)
        res = test_client.put(self.url(self.first_object),
                              data=data)
        assert res.status_code == 403
        assert self.model.query.count() == OBJECT_COUNT

    def test_update_fails_with_existing(self, test_client, session):
        """Taking a unique value already used by another object yields 409."""
        for unique_field in self.unique_fields:
            data = self.factory.build_dict()
            data[unique_field] = getattr(self.objects[1], unique_field)
            res = test_client.put(self.url(self.first_object), data=data)
            assert res.status_code == 409
            assert self.model.query.count() == OBJECT_COUNT

    def test_update_an_object_fails_with_empty_dict(self, test_client):
        """To do this the user should use a PATCH request"""
        res = test_client.put(self.url(self.first_object), data={})
        assert res.status_code == 400

    def test_update_cant_change_id(self, test_client):
        """An ``id`` in the payload is ignored; the object keeps its id."""
        raw_json = self.factory.build_dict(workspace=self.workspace)
        expected_id = self.first_object.id
        raw_json['id'] = 100000
        res = test_client.put(self.url(self.first_object),
                              data=raw_json)
        assert res.status_code == 200
        assert res.json['id'] == expected_id


class CountTestsMixin:
    """Placeholder mixin: no count tests are defined yet."""


class DeleteTestsMixin:
    """DELETE tests for a single object of a workspace-scoped endpoint."""

    def test_delete(self, test_client):
        """Deleting an object returns 204 and removes exactly that object."""
        res = test_client.delete(self.url(self.first_object))
        assert res.status_code == 204  # No content
        assert was_deleted(self.first_object)
        assert self.model.query.count() == OBJECT_COUNT - 1

    def test_delete_readonly_fails(self, test_client, session):
        """Deleting in a read-only workspace is rejected and keeps the object."""
        self.workspace.readonly = True
        session.commit()
        res = test_client.delete(self.url(self.first_object))
        assert res.status_code == 403  # Forbidden
        assert not was_deleted(self.first_object)
        assert self.model.query.count() == OBJECT_COUNT

    def test_delete_inactive_fails(self, test_client):
        """Deleting in a deactivated workspace is rejected with 403."""
        self.workspace.deactivate()
        db.session.commit()
        res = test_client.delete(self.url(self.first_object))
        assert res.status_code == 403
        assert not was_deleted(self.first_object)
        assert self.model.query.count() == OBJECT_COUNT

    def test_delete_from_other_workspace_fails(self, test_client,
                                               second_workspace):
        """An object cannot be deleted through another workspace's URL."""
        res = test_client.delete(self.url(self.first_object,
                                          workspace=second_workspace))
        assert res.status_code == 404  # Not found in that workspace
        assert not was_deleted(self.first_object)
        assert self.model.query.count() == OBJECT_COUNT


class PaginationTestsMixin(OriginalPaginationTestsMixin):
    """Pagination tests whose extra objects live in ``self.workspace``."""

    def create_many_objects(self, session, n):
        """Build ``n`` objects in the current workspace, commit, return them."""
        batch = self.factory.create_batch(n, workspace=self.workspace)
        session.commit()
        return batch


class ReadWriteTestsMixin(ListTestsMixin,
                          RetrieveTestsMixin,
                          CreateTestsMixin,
                          CountTestsMixin,
                          UpdateTestsMixin,
                          DeleteTestsMixin):
    """Bundle of the list, retrieve, create, count, update and delete mixins."""


class ReadWriteAPITests(ReadWriteTestsMixin,
                        GenericAPITest):
    """Full read/write test suite for a workspace-scoped endpoint."""


class ReadOnlyAPITests(ListTestsMixin,
                       RetrieveTestsMixin,
                       GenericAPITest):
    """List and retrieve tests only, for endpoints that cannot be written."""