diff --git a/rdmo/core/constants.py b/rdmo/core/constants.py index d78c29f4a5..d2a92c5340 100644 --- a/rdmo/core/constants.py +++ b/rdmo/core/constants.py @@ -81,15 +81,15 @@ } RDMO_MODELS = { - 'catalog': 'questions.catalog', - 'section': 'questions.section', - 'page': 'questions.page', - 'questionset': 'questions.questionset', - 'question': 'questions.question', 'attribute': 'domain.attribute', - 'optionset': 'options.optionset', - 'option': 'options.option', 'condition': 'conditions.condition', + 'option': 'options.option', + 'optionset': 'options.optionset', + 'question': 'questions.question', + 'questionset': 'questions.questionset', + 'page': 'questions.page', + 'section': 'questions.section', + 'catalog': 'questions.catalog', 'task': 'tasks.task', 'view': 'views.view' } diff --git a/rdmo/core/import_helpers.py b/rdmo/core/import_helpers.py index 4d2be019b1..544db24971 100644 --- a/rdmo/core/import_helpers.py +++ b/rdmo/core/import_helpers.py @@ -33,6 +33,8 @@ def get_value(self, **kwargs): @staticmethod def get_value_from_callback(callback, kwargs): + if not callable(callback): + raise TypeError('callback must be callable') sig = signature(callback) kwargs = {k: val for k, val in kwargs.items() if k in sig.parameters} value = callback(**kwargs) diff --git a/rdmo/core/imports.py b/rdmo/core/imports.py index 2a2a165e40..7fd28969c4 100644 --- a/rdmo/core/imports.py +++ b/rdmo/core/imports.py @@ -138,11 +138,6 @@ def track_changes_on_element(element: dict, lookup_field = element_field if instance_field is None else instance_field original_value = getattr(original, lookup_field, '') - if isinstance(new_value,str) and isinstance(original_value,int): - # typecasting of original value to str, for comparison '0' == 0 - # specific edge-case, maybe generalize later - original_value = str(original_value) - element[ImportElementFields.DIFF][element_field][ImportElementFields.CURRENT] = original_value element[ImportElementFields.DIFF][element_field][ImportElementFields.NEW] = new_value @@ -207,7 +202,7 @@ def track_changes_on_uri_of_foreign_field(element, field_name, foreign_uri, orig track_changes_on_element(element, field_name, new_value=foreign_uri, original_value=original_foreign_uri) -def set_foreign_field(instance, field_name, element, uploaded_uris=None, original=None) -> None: +def set_foreign_field(instance, field_name, element, original=None) -> None: if field_name not in element: return @@ -278,45 +273,46 @@ def set_foreign_field(instance, field_name, element, uploaded_uris=None, origina def set_extra_field(instance, field_name, element, - extra_field_helper: Optional[ExtraFieldHelper] = None, original=None) -> None: + extra_field_helper: Optional[ExtraFieldHelper] = None, + ) -> None: - extra_value = None + extra_field_value = None if field_name in element: - extra_value = element.get(field_name) + extra_field_value = element.get(field_name) else: # get the default field value from the instance instance_value = getattr(instance, field_name) element[field_name] = instance_value - extra_value = instance_value + extra_field_value = instance_value if extra_field_helper is not None: # default_value extra_value_from_helper = extra_field_helper.get_value(instance=instance, key=field_name) # overwrite None or '' values by the get_value from the helper - if extra_value is None or extra_value == '': - extra_value = extra_value_from_helper + if extra_field_value is None or extra_field_value == '': + extra_field_value = extra_value_from_helper if extra_field_helper.overwrite_in_element: - element[field_name] = extra_value - - if extra_value is not None: - setattr(instance, field_name, extra_value) - # track changes - track_changes_on_element(element, field_name, new_value=extra_value, original=original) + element[field_name] = extra_field_value + if extra_field_value is not None: + setattr(instance, field_name, extra_field_value) def track_changes_m2m_instances(element, field_name, foreign_instances, original=None): if original is None: return original_m2m_instance = getattr(original, field_name) - if original_m2m_instance is None: - return - original_m2m_uris = list(original_m2m_instance.values_list('uri', flat=True)) - foreign_uris = [i.uri for i in foreign_instances] - track_changes_on_element(element, field_name, new_value=foreign_uris, - original_value=original_m2m_uris) + original_m2m_instance = original_m2m_instance or [] + # m2m instance fields are unordered so comparison by set + original_uris = set(original_m2m_instance.values_list('uri', flat=True)) + foreign_uris = {i.uri for i in foreign_instances} + common_uris = list(original_uris & foreign_uris) + original_uris_list = common_uris + list(original_uris - foreign_uris) + foreign_uris_list = common_uris + list(foreign_uris - original_uris) + track_changes_on_element(element, field_name, new_value=foreign_uris_list, + original_value=original_uris_list) def set_m2m_through_instances(instance, element, field_name=None, source_name=None, @@ -354,6 +350,7 @@ def set_m2m_through_instances(instance, element, field_name=None, source_name=No for target_element in target_elements: target_uri = target_element.get('uri') + target_element['order'] = int(target_element['order']) # cast to int for ordering order = target_element.get('order') new_data.append(target_element) @@ -497,6 +494,7 @@ def set_reverse_m2m_through_instance(instance, element, field_name=None, source_ for target_element in target_elements: target_uri = target_element.get('uri') + target_element['order'] = int(target_element['order']) # cast to int for ordering order = target_element.get('order') new_data.append(target_element) diff --git a/rdmo/core/xml.py b/rdmo/core/xml.py index 175b9c25c4..264977352e 100644 --- a/rdmo/core/xml.py +++ b/rdmo/core/xml.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -DEFAULT_RDMO_XML_VERSION = '1.11.0' +LEGACY_RDMO_XML_VERSION = '1.11.0' ELEMENTS_USING_KEY = {RDMO_MODELS['attribute']} @@ -45,7 +45,7 @@ def validate_root(root: Optional[xmlElement]) -> Tuple[bool, Optional[str]]: def validate_and_get_xml_version_from_root(root: xmlElement) -> Tuple[Optional[Version], list]: - unparsed_root_version = root.attrib.get('version') or DEFAULT_RDMO_XML_VERSION + unparsed_root_version = root.attrib.get('version') or LEGACY_RDMO_XML_VERSION root_version, rdmo_version = parse(unparsed_root_version), parse(RDMO_INSTANCE_VERSION) if root_version > rdmo_version: logger.info('Import failed version validation (%s > %s)', root_version, rdmo_version) @@ -117,13 +117,13 @@ def parse_xml_to_elements(xml_file=None) -> Tuple[OrderedDict, list]: return OrderedDict(), errors # step 3.1.1: validate the legacy elements - legacy_errors = validate_legacy_elements(elements, parse(root.attrib.get('version', DEFAULT_RDMO_XML_VERSION))) + legacy_errors = validate_legacy_elements(elements, parse(root.attrib.get('version', LEGACY_RDMO_XML_VERSION))) if legacy_errors: errors.extend(legacy_errors) return OrderedDict(), errors # step 4: convert elements from previous versions - elements = convert_elements(elements, parse(root.attrib.get('version', DEFAULT_RDMO_XML_VERSION))) + elements = convert_elements(elements, parse(root.attrib.get('version', LEGACY_RDMO_XML_VERSION))) # step 5: order the elements and return # ordering of elements is done in the import_elements function @@ -257,6 +257,22 @@ def validate_pre_conversion_for_missing_key_in_legacy_elements(elements, version raise ValueError(f"Missing legacy elements, elements containing 'key' were expected for this XML with version {version} and elements {models_in_elements}.") # noqa: E501 +def update_related_legacy_elements(elements: Dict, + target_uri: str, source_model: str, + legacy_element_field: str, element_field: str): + # search for the related elements that use the uri + related_elements = [ + element for element in elements.values() + if element['model'] == source_model + and element.get(legacy_element_field, {}).get('uri') == target_uri + ] + # write the related elements back into the related element + elements[target_uri][element_field] = [ + {k: v for k, v in element.items() if k in ('uri', 'model', 'order')} + for element in related_elements + ] + + def convert_legacy_elements(elements): # first pass: identify pages for uri, element in elements.items(): @@ -275,20 +291,30 @@ def convert_legacy_elements(elements): elif element['model'] == 'questions.catalog': element['uri_path'] = element.pop('key') + # Add sections to the catalog + update_related_legacy_elements(elements, uri, 'questions.section', 'catalog', 'sections') elif element['model'] == 'questions.section': del element['key'] element['uri_path'] = element.pop('path') - - if element.get('catalog') is not None: - element['catalog']['order'] = element.pop('order') + del element['catalog'] # sections do not have catalog anymore + # Add section_pages to the section + update_related_legacy_elements(elements, uri, 'questions.page', 'section', 'pages') elif element['model'] == 'questions.page': del element['key'] element['uri_path'] = element.pop('path') + del element['section'] # pages do not have sections anymore - if element.get('section') is not None: - element['section']['order'] = element.pop('order') + # Add page_questionsets to the page + # Add questionsets to the page + update_related_legacy_elements(elements, uri, 'questions.questionset', 'questionset', 'questionsets') + + # Add page_questions to the page + update_related_legacy_elements(elements, uri, 'questions.question', 'question', 'questions') + + # Add page_conditions to the page + update_related_legacy_elements(elements, uri, 'conditions.condition', 'condition', 'conditions') elif element['model'] == 'questions.questionset': del element['key'] @@ -298,11 +324,15 @@ def convert_legacy_elements(elements): if parent is not None: if elements[parent].get('model') == 'questions.page': # this questionset belongs to a page now - del element['questionset'] - element['page'] = { - 'uri': parent, + parent_questionsets = elements[parent].get('questionset') + parent_questionsets = parent_questionsets or [] + parent_questionsets.append({ + 'uri': element['uri'], + 'model': element['model'], 'order': element.pop('order') - } + }) + elements[parent]['questionset'] = parent_questionsets + del element['questionset'] else: # this questionset still belongs to a questionset element['questionset']['order'] = element.pop('order') @@ -313,26 +343,26 @@ def convert_legacy_elements(elements): parent = element.get('questionset').get('uri') if parent is not None: - if elements[parent].get('model') == 'questions.page': - # this question belongs to a page now - del element['questionset'] - element['page'] = { - 'uri': parent, - 'order': element.pop('order') - } - else: - # this question still belongs to a questionset - element['questionset']['order'] = element.pop('order') + parent_questionsets = elements[parent].get('questions', []) + parent_questionsets.append({ + 'uri': element['uri'], + 'model': element['model'], + 'order': element.pop('order') + }) + elements[parent]['questions'] = parent_questionsets + del element['questionset'] elif element['model'] == 'options.optionset': element['uri_path'] = element.pop('key') + update_related_legacy_elements(elements, uri, 'options.option', 'optionset', 'options') + elif element['model'] == 'options.option': del element['key'] element['uri_path'] = element.pop('path') - if element.get('optionset') is not None: - element['optionset']['order'] = element.pop('order') + del element['optionset'] # options do not have optionsets anymore + if element['model'] == 'tasks.task': element['uri_path'] = element.pop('key') @@ -356,59 +386,21 @@ def convert_additional_input(elements): return elements -def get_related_elements(element, ignored_keys=None): - ignored_keys = ignored_keys or list(ImportElementFields) - related_elements = {k: val for k, val in element.items() if - k not in ignored_keys and (isinstance(val, (dict, list)))} - return related_elements - - -def sort_by_relatives(elements, descendants_first=False, ancestors_first=False): - ancestors, descendants = [], [] - if not descendants_first and not ancestors_first: - return elements - - for uri, element in elements.items(): - try: - has_descendants = get_related_elements(element) - except AttributeError: - has_descendants = False - if has_descendants: - ancestors.append((uri, element)) - else: - descendants.append((uri, element)) - if descendants_first: - sort_list = descendants + ancestors - elif ancestors_first: - sort_list = ancestors + descendants - - sorted_elements = OrderedDict() - for uri,element in sort_list: - sorted_elements[uri] = element - return sorted_elements - - -def order_elements(elements, order_sets_first=False, descendants_first=False) -> OrderedDict: +def order_elements(elements: OrderedDict) -> OrderedDict: ordered_elements = OrderedDict() - if descendants_first: - elements = sort_by_relatives(elements, descendants_first=descendants_first) - for uri, element in elements.items(): - append_element(ordered_elements, elements, uri, element, order_sets_first=order_sets_first) + for uri, element in reversed(elements.items()): + append_element(ordered_elements, elements, uri, element,) return ordered_elements -def append_element(ordered_elements, unordered_elements, uri, element, order_sets_first=False) -> None: +def append_element(ordered_elements, unordered_elements, uri, element) -> None: if element is None: return + for key, element_value in element.items(): + if key in list(ImportElementFields): + continue - related_elements = get_related_elements(element) - - if order_sets_first: - if related_elements and uri not in ordered_elements: - ordered_elements[uri] = element - - for key, element_value in related_elements.items(): if isinstance(element_value, dict): sub_uri = element_value.get('uri') sub_element = unordered_elements.get(sub_uri) @@ -417,11 +409,10 @@ def append_element(ordered_elements, unordered_elements, uri, element, order_set elif isinstance(element_value, list): for value in element_value: - if isinstance(element_value, dict): - sub_uri = value.get('uri') - sub_element = unordered_elements.get(sub_uri) - if sub_uri not in ordered_elements and sub_uri is not None: - append_element(ordered_elements, unordered_elements, sub_uri, sub_element) + sub_uri = value.get('uri') + sub_element = unordered_elements.get(sub_uri) + if sub_uri not in ordered_elements and sub_uri is not None: + append_element(ordered_elements, unordered_elements, sub_uri, sub_element) if uri not in ordered_elements: ordered_elements[uri] = element diff --git a/rdmo/domain/imports.py b/rdmo/domain/imports.py index cf4aefed15..32627133a0 100644 --- a/rdmo/domain/imports.py +++ b/rdmo/domain/imports.py @@ -18,7 +18,6 @@ def build_attribute_uri(instance: Optional[Attribute]=None): return instance.build_uri(instance.uri_prefix, instance.path) - import_helper_attribute = ElementImportHelper( model=Attribute, common_fields=('uri_prefix', 'key', 'comment'), diff --git a/rdmo/management/assets/js/components/import/common/WarningsListGroup.js b/rdmo/management/assets/js/components/import/common/WarningsListGroup.js index 00c6e3d99f..efcfb665f5 100644 --- a/rdmo/management/assets/js/components/import/common/WarningsListGroup.js +++ b/rdmo/management/assets/js/components/import/common/WarningsListGroup.js @@ -37,8 +37,4 @@ WarningsListGroup.propTypes = { shouldShowURI: PropTypes.bool } -WarningsListGroup.defaultProps = { - shouldShowURI: true, -} - export default WarningsListGroup diff --git a/rdmo/management/import_utils.py b/rdmo/management/import_utils.py index 92ec4cfa5e..792b1a3870 100644 --- a/rdmo/management/import_utils.py +++ b/rdmo/management/import_utils.py @@ -1,6 +1,12 @@ +import logging from collections import defaultdict from dataclasses import asdict -from typing import Dict +from typing import Dict, Set, Tuple + +from django.db.models import Model + +from rest_framework.exceptions import ValidationError +from rest_framework.serializers import ModelSerializer from rdmo.core.imports import ( ImportElementFields, @@ -11,8 +17,11 @@ set_m2m_instances, set_m2m_through_instances, set_reverse_m2m_through_instance, + track_changes_on_element, ) +logger = logging.getLogger(__name__) + IMPORT_ELEMENT_INIT_DICT = { ImportElementFields.WARNINGS: lambda: defaultdict(list), ImportElementFields.ERRORS: list, @@ -22,10 +31,40 @@ } +def is_valid_import_element(element: dict) -> bool: + if element is None or not isinstance(element, dict): + return False + if not all(i in element for i in ['model', 'uri']): + return False + return True + + +def get_redundant_keys_from_element(element_keys: Set, model: Model) -> Set: + model_fields = {i.name for i in model._meta.get_fields()} + required_element_keys = {'uri', 'model'} + import_dict_keys = {i.value for i in IMPORT_ELEMENT_INIT_DICT.keys()} + redundant_keys = element_keys - model_fields - required_element_keys - import_dict_keys + + lang_fields_prefix = {i.split('_lang')[0] for i in model_fields if 'lang' in i} + element_lang_keys = {i for i in element_keys if any(i.startswith(a) for a in lang_fields_prefix)} + redundant_keys = redundant_keys - element_lang_keys + return redundant_keys + def initialize_import_element_dict(element: Dict) -> None: # initialize element dict with default values for _k,_val in IMPORT_ELEMENT_INIT_DICT.items(): element[_k] = _val() + return element + + +def initialize_and_clean_import_element_dict(element: Dict, model: Model) -> Tuple[Dict, Dict]: + redundant_keys = get_redundant_keys_from_element(set(element.keys()), model) + excluded_element_data = {} + for k in redundant_keys: + excluded_element_data[k] = element.pop(k) + # initialize element dict with default values + element = initialize_import_element_dict(element) + return element, excluded_element_data def strip_uri_prefix_endswith_slash(element: dict) -> dict: @@ -35,7 +74,7 @@ def strip_uri_prefix_endswith_slash(element: dict) -> dict: return element -def apply_field_values(instance, element, import_helper, uploaded_uris, original) -> None: +def apply_field_values(instance, element, import_helper, original) -> None: """Applies the field values from the element to the instance.""" # start to set values on the instance # set common field values from element on instance @@ -46,10 +85,43 @@ def apply_field_values(instance, element, import_helper, uploaded_uris, original set_lang_field(instance, field, element, original=original) # set foreign fields for field in import_helper.foreign_fields: - set_foreign_field(instance, field, element, uploaded_uris=uploaded_uris, original=original) + set_foreign_field(instance, field, element, original=original) + # set extra fields, track changes is done after instance.full_clean + for extra_field in import_helper.extra_fields: + set_extra_field(instance, extra_field.field_name, element, + extra_field_helper=extra_field) + + +def validate_with_serializer_field(instance, field_name, value): + """Validate and convert a value using the corresponding DRF serializer field.""" + model_field = instance._meta.get_field(field_name) + drf_field_class = ModelSerializer.serializer_field_mapping.get(type(model_field)) + + if drf_field_class is not None: + try: + drf_field = drf_field_class() + return drf_field.to_internal_value(value) + except (ValidationError, TypeError, ValueError) as e: + logger.debug("Cannot convert '%s' for field '%s' using '%s': %s", + value, field_name, drf_field_class.__name__, str(e)) + return None + + +def update_extra_fields_from_validated_instance(instance, element, import_helper, original=None) -> None: for extra_field in import_helper.extra_fields: - set_extra_field(instance, extra_field.field_name, element, extra_field_helper=extra_field, original=original) + field_name = extra_field.field_name + + element_field_value = element.get(field_name) + + # deserialize the element value by using the drf field serializer + validated_value = validate_with_serializer_field(instance, field_name, element_field_value) + + if validated_value is not None: + element[field_name] = validated_value + + # track changes + track_changes_on_element(element, field_name, new_value=element[field_name], original=original) def update_related_fields(instance, element, import_helper, original, save) -> None: diff --git a/rdmo/management/imports.py b/rdmo/management/imports.py index 10679d100a..3e483d24aa 100644 --- a/rdmo/management/imports.py +++ b/rdmo/management/imports.py @@ -1,6 +1,7 @@ import copy import logging -from typing import AbstractSet, Dict, List, Optional +from collections import OrderedDict +from typing import Dict, List, Optional from django.conf import settings from django.contrib.sites.shortcuts import get_current_site @@ -8,6 +9,7 @@ from rdmo.conditions.imports import import_helper_condition from rdmo.core.imports import ( + ImportElementFields, check_permissions, get_or_return_instance, make_import_info_msg, @@ -18,8 +20,11 @@ from rdmo.management.import_utils import ( add_current_site_to_sites_and_editor, apply_field_values, + initialize_and_clean_import_element_dict, initialize_import_element_dict, + is_valid_import_element, strip_uri_prefix_endswith_slash, + update_extra_fields_from_validated_instance, update_related_fields, ) from rdmo.options.imports import import_helper_option, import_helper_optionset @@ -36,41 +41,54 @@ logger = logging.getLogger(__name__) ELEMENT_IMPORT_HELPERS = { - "conditions.condition": import_helper_condition, "domain.attribute": import_helper_attribute, - "options.optionset": import_helper_optionset, "options.option": import_helper_option, - "questions.catalog": import_helper_catalog, + "conditions.condition": import_helper_condition, + "options.optionset": import_helper_optionset, + "questions.question": import_helper_question, + "questions.questionset": import_helper_questionset, "questions.section": import_helper_section, "questions.page": import_helper_page, - "questions.questionset": import_helper_questionset, - "questions.question": import_helper_question, + "questions.catalog": import_helper_catalog, "tasks.task": import_helper_task, "views.view": import_helper_view } -def import_elements(uploaded_elements: Dict, save: bool = True, request: Optional[HttpRequest] = None) -> List[Dict]: +def import_elements(uploaded_elements: OrderedDict, + save: bool = True, + request: Optional[HttpRequest] = None) -> List[Dict]: imported_elements = [] - uploaded_elements_ordering_index = {uri: n for n, uri in enumerate(uploaded_elements.keys())} + uploaded_elements_initial_ordering = {uri: n for n, uri in enumerate(uploaded_elements.keys())} uploaded_uris = set(uploaded_elements.keys()) current_site = get_current_site(request) if save: - # when saving, the descendant elements go first - uploaded_elements = order_elements(uploaded_elements, descendants_first=True) + # when saving, the elements are ordered according to the rdmo models + pass + uploaded_elements = order_elements(uploaded_elements) for _uri, uploaded_element in uploaded_elements.items(): - element = import_element(element=uploaded_element, - save=save, - uploaded_uris=uploaded_uris, - request=request, - current_site=current_site) - element['warnings'] = {k: val for k, val in element['warnings'].items() if k not in uploaded_uris} + if not is_valid_import_element(uploaded_element): + continue + element = import_element( + element=uploaded_element, + save=save, + request=request, + current_site=current_site + ) + element[ImportElementFields.WARNINGS] = { + k: val for + k, val in element[ImportElementFields.WARNINGS].items() + if k not in uploaded_uris + } imported_elements.append(element) - # sort elements back to order of uploaded elements - imported_elements = sorted(imported_elements, - key=lambda x: uploaded_elements_ordering_index.get(x['uri'], float('inf'))) + # sort elements back to initial order of uploaded elements + imported_elements = sorted( + imported_elements, + key=lambda x: uploaded_elements_initial_ordering.get(x['uri'], + float('inf')) + ) return imported_elements @@ -79,20 +97,17 @@ def import_element( element: Optional[Dict] = None, save: bool = True, request: Optional[HttpRequest] = None, - uploaded_uris: Optional[AbstractSet[str]] = None, current_site = None ) -> Dict: - if element is None or not isinstance(element, dict): - return {} - if 'model' not in element: - return {} - initialize_import_element_dict(element) import_helper = ELEMENT_IMPORT_HELPERS[element['model']] uri = element.get('uri') + + element, _excluded_data = initialize_and_clean_import_element_dict(element, import_helper.model) + # get or create instance from uri and model instance, created = get_or_return_instance(import_helper.model, uri=uri) @@ -109,25 +124,27 @@ def import_element( perms_error_msg = check_permissions(instance, uri, user) if perms_error_msg: # when there is an error msg, the import can be stopped and return - element["errors"].append(perms_error_msg) + element[ImportElementFields.ERRORS].append(perms_error_msg) return element - element['created'] = created - element['updated'] = not created and original is not None + element[ImportElementFields.CREATED] = created + element[ImportElementFields.UPDATED] = not created and original is not None # INFO: the dict element[FieldNames.diff.value] is filled by calling track_changes_on_element element = strip_uri_prefix_endswith_slash(element) # start to set values on the instance - apply_field_values(instance, element, import_helper, uploaded_uris, original) + apply_field_values(instance, element, import_helper, original) # call the validators on the instance validate_instance(instance, element, *import_helper.validators) - if element.get('errors'): + update_extra_fields_from_validated_instance(instance, element, import_helper, original=original) + + if element.get(ImportElementFields.ERRORS): # when there is an error msg, the import can be stopped and return if save: - element['created'] = False - element['updated'] = False + element[ImportElementFields.CREATED] = False + element[ImportElementFields.UPDATED] = False return element if save: diff --git a/rdmo/management/tests/helpers_import_elements.py b/rdmo/management/tests/helpers_import_elements.py index a1102c14db..d50b923952 100644 --- a/rdmo/management/tests/helpers_import_elements.py +++ b/rdmo/management/tests/helpers_import_elements.py @@ -1,3 +1,4 @@ +import random from collections import OrderedDict from functools import partial from typing import Dict, List, Optional, Tuple, Union @@ -72,7 +73,13 @@ def _test_helper_change_fields_elements(elements, return _new_elements -def parse_xml_and_import_elements(xml_file): +def parse_xml_and_import_elements(xml_file, shuffle_elements=False): elements, root = read_xml_and_parse_to_root_and_elements(xml_file) + if shuffle_elements: + # Extract items from the OrderedDict + items = list(elements.items()) + # Shuffle the list of items + random.shuffle(items) + elements = OrderedDict(items) imported_elements = import_elements(elements) return elements, root, imported_elements diff --git a/rdmo/management/tests/test_import_options.py b/rdmo/management/tests/test_import_options.py index d1f54b156b..1af5787096 100644 --- a/rdmo/management/tests/test_import_options.py +++ b/rdmo/management/tests/test_import_options.py @@ -28,6 +28,28 @@ }, } +OPTIONSET_URIS = { + "http://example.com/terms/options/condition": [ + "http://example.com/terms/options/condition/other" + ], + "http://example.com/terms/options/one_two_three": [ + "http://example.com/terms/options/one_two_three/one", + "http://example.com/terms/options/one_two_three/two", + "http://example.com/terms/options/one_two_three/three", + ], + "http://example.com/terms/options/one_two_three_other": [ + "http://example.com/terms/options/one_two_three_other/one", + "http://example.com/terms/options/one_two_three_other/two", + "http://example.com/terms/options/one_two_three_other/three", + "http://example.com/terms/options/one_two_three_other/text", + "http://example.com/terms/options/one_two_three_other/textarea" + ], + "http://example.com/terms/options/plugin": [] +} +LEGACY_SKIP_URIS = [ + "http://example.com/terms/options/one_two_three_other/textarea" +] + def test_create_optionsets(db, settings): delete_all_objects([OptionSet, Option]) @@ -40,7 +62,14 @@ def test_create_optionsets(db, settings): assert Option.objects.count() == 9 assert all(element[ImportElementFields.CREATED] is True for element in imported_elements) assert all(element[ImportElementFields.UPDATED] is False for element in imported_elements) - + for optionset_uri, options_uris in OPTIONSET_URIS.items(): + db_optionset = OptionSet.objects.get(uri=optionset_uri) + db_options = Option.objects.filter(uri__in=options_uris) + db_options_uris = db_options.values_list('uri', flat=True) + assert set(db_options_uris) == set(options_uris) + db_ordered_options_uris = db_optionset.options.filter(uri__in=options_uris).order_by( + 'option_optionsets__order').values_list('uri',flat=True) + assert options_uris == list(db_ordered_options_uris) def test_update_optionsets(db, settings): xml_file = Path(settings.BASE_DIR) / 'xml' / 'elements' / 'optionsets.xml' @@ -189,6 +218,16 @@ def test_create_legacy_options(db, settings): assert Option.objects.count() == 8 assert all(element[ImportElementFields.CREATED] is True for element in imported_elements) assert all(element[ImportElementFields.UPDATED] is False for element in imported_elements) + for optionset_uri, test_options_uris in OPTIONSET_URIS.items(): + # legacy has no "http://example.com/terms/options/one_two_three_other/textarea" + options_uris = [i for i in test_options_uris if i not in LEGACY_SKIP_URIS] + db_optionset = OptionSet.objects.get(uri=optionset_uri) + db_options = Option.objects.filter(uri__in=options_uris) + db_options_uris = db_options.values_list('uri', flat=True) + assert set(db_options_uris) == set(options_uris) + db_ordered_options_uris = db_optionset.options.filter(uri__in=options_uris).order_by( + 'option_optionsets__order').values_list('uri',flat=True) + assert options_uris == list(db_ordered_options_uris) def test_update_legacy_options(db, settings): diff --git a/rdmo/management/tests/test_import_questions.py b/rdmo/management/tests/test_import_questions.py index 6145c3fce4..fadf5127d6 100644 --- a/rdmo/management/tests/test_import_questions.py +++ b/rdmo/management/tests/test_import_questions.py @@ -15,13 +15,23 @@ fields_to_be_changed = (('comment',),) - -def test_create_catalogs(db, settings): +TEST_CATALOG_SECTIONS_URIS = { + "http://example.com/terms/questions/catalog/individual", + "http://example.com/terms/questions/catalog/collections", + "http://example.com/terms/questions/catalog/set", + "http://example.com/terms/questions/catalog/conditions", + "http://example.com/terms/questions/catalog/options", + "http://example.com/terms/questions/catalog/blocks" +} + + +@pytest.mark.parametrize('shuffle', [True, False]) +def test_create_catalogs(db, settings, shuffle): delete_all_objects([Catalog, Section, Page, QuestionSet, Question]) xml_file = Path(settings.BASE_DIR) / 'xml' / 'elements' / 'catalogs.xml' - elements, root, imported_elements = parse_xml_and_import_elements(xml_file) + elements, root, imported_elements = parse_xml_and_import_elements(xml_file, shuffle_elements=shuffle) assert len(root) == len(imported_elements) == 148 assert Catalog.objects.count() == 2 @@ -32,6 +42,18 @@ def test_create_catalogs(db, settings): assert all(element['created'] is True for element in imported_elements) assert all(element['updated'] is False for element in imported_elements) + # check that all elements ended up in the catalog + catalog = Catalog.objects.prefetch_elements().get(uri="http://example.com/terms/questions/catalog") + catalog_sections = catalog.sections.all() + catalog_sections_uris = set(catalog_sections.values_list('uri', flat=True)) + assert catalog_sections_uris == TEST_CATALOG_SECTIONS_URIS + + sections_pages = Section.objects.filter(uri__in=catalog_sections_uris).values_list('pages') + assert sections_pages.distinct().count() == 48 + sections_pages_questionsets = Page.objects.filter(id__in=sections_pages).values_list('questionsets') + assert sections_pages_questionsets.distinct().count() == 3 + sections_pages_questions = Page.objects.filter(id__in=sections_pages).values_list('questions') + assert sections_pages_questions.distinct().count() == 85 def test_update_catalogs(db, settings): xml_file = Path(settings.BASE_DIR) / 'xml' / 'elements' / 'catalogs.xml' @@ -39,6 +61,7 @@ def test_update_catalogs(db, settings): elements, root, imported_elements = parse_xml_and_import_elements(xml_file) assert len(root) == len(imported_elements) == 148 + assert all(element['created'] is False for element in imported_elements) assert all(element['updated'] is True for element in imported_elements) @@ -205,12 +228,13 @@ def test_update_questionsets_with_changed_fields(db, settings, updated_fields): assert test[ImportElementFields.DIFF] == imported[ImportElementFields.DIFF] -def test_create_questions(db, settings): +@pytest.mark.parametrize('shuffle', [True, False]) +def test_create_questions(db, settings, shuffle): delete_all_objects([Page, QuestionSet, Question]) xml_file = Path(settings.BASE_DIR) / 'xml' / 'elements' / 'questions.xml' - elements, root, imported_elements = parse_xml_and_import_elements(xml_file) + elements, root, imported_elements = parse_xml_and_import_elements(xml_file, shuffle_elements=shuffle) assert len(root) == len(imported_elements) == 89 assert Question.objects.count() == 89 @@ -249,12 +273,13 @@ def test_update_questions_with_changed_fields(db, settings, updated_fields): assert test[ImportElementFields.DIFF] == imported[ImportElementFields.DIFF] -def test_create_legacy_questions(db, settings): +@pytest.mark.parametrize('shuffle', [True, False]) +def test_create_legacy_questions(db, settings, shuffle): delete_all_objects([Catalog, Section, Page, QuestionSet, Question]) xml_file = Path(settings.BASE_DIR) / 'xml' / 'elements' / 'legacy' / 'questions.xml' - elements, root, imported_elements = parse_xml_and_import_elements(xml_file) + elements, root, imported_elements = parse_xml_and_import_elements(xml_file, shuffle_elements=shuffle) assert len(root) == len(imported_elements) == 147 assert Catalog.objects.count() == 1 @@ -267,9 +292,15 @@ def test_create_legacy_questions(db, settings): # check that all elements ended up in the catalog catalog = Catalog.objects.prefetch_elements().first() - descendant_uris = {element.uri for element in catalog.descendants} - element_uris = {element['uri'] for _uri, element in elements.items() if element['uri'] != catalog.uri} - assert descendant_uris == element_uris + catalog_sections = catalog.sections.all() + catalog_sections_uris = set(catalog_sections.values_list('uri', flat=True)) + assert catalog_sections_uris == TEST_CATALOG_SECTIONS_URIS + sections_pages = Section.objects.filter(uri__in=catalog_sections_uris).values_list('pages') + assert sections_pages.distinct().count() == 48 + sections_pages_questionsets = Page.objects.filter(id__in=sections_pages).values_list('questionsets') + assert sections_pages_questionsets.distinct().count() == 3 + sections_pages_questions = Page.objects.filter(id__in=sections_pages).values_list('questions') + assert sections_pages_questions.distinct().count() == 85 def test_update_legacy_questions(db, settings): @@ -283,15 +314,6 @@ def test_update_legacy_questions(db, settings): # check that all elements ended up in the catalog catalog = Catalog.objects.prefetch_elements().first() - descendant_uris = { - element.uri for element in catalog.descendants if any(element.uri.startswith(uri) for uri in [ - 'http://example.com/terms/questions/catalog/individual', - 'http://example.com/terms/questions/catalog/set', - 'http://example.com/terms/questions/catalog/collections', - 'http://example.com/terms/questions/catalog/conditions', - 'http://example.com/terms/questions/catalog/options', - 'http://example.com/terms/questions/catalog/blocks' - ]) - } - element_uris = {element['uri'] for _uri, element in elements.items() if element['uri'] != catalog.uri} - assert descendant_uris == element_uris + catalog_sections = catalog.sections.all() + catalog_sections_uris = set(catalog_sections.values_list('uri', flat=True)) + assert catalog_sections_uris == TEST_CATALOG_SECTIONS_URIS diff --git a/rdmo/management/viewsets.py b/rdmo/management/viewsets.py index 3bad097651..2b5ec2d295 100644 --- a/rdmo/management/viewsets.py +++ b/rdmo/management/viewsets.py @@ -72,8 +72,7 @@ def create(self, request, *args, **kwargs): # step 1: store xml file as tmp file try: elements_data = request.data['elements'] - _elements = filter(lambda x: 'uri' in x, elements_data) - elements = {i['uri']: i for i in _elements} + elements = {i['uri']: i for i in elements_data if 'uri' in i} except KeyError as e: raise ValidationError({'elements': [_('This field may not be blank.')]}) from e except TypeError as e: diff --git a/rdmo/questions/imports.py b/rdmo/questions/imports.py index 93f9eeff4b..c923f4e78f 100644 --- a/rdmo/questions/imports.py +++ b/rdmo/questions/imports.py @@ -17,10 +17,10 @@ ) import_helper_catalog = ElementImportHelper( - model = Catalog, + model=Catalog, validators=(CatalogLockedValidator, CatalogUniqueURIValidator), - lang_fields=('help', 'title'), - extra_fields = ( + lang_fields=('title', 'help'), + extra_fields=( ExtraFieldHelper(field_name='order'), ExtraFieldHelper(field_name='available', overwrite_in_element=True), ), @@ -34,7 +34,7 @@ ) import_helper_section = ElementImportHelper( - model = Section, + model=Section, validators=(SectionLockedValidator, SectionUniqueURIValidator), lang_fields=('title', 'short_title'), m2m_through_instance_fields=[ @@ -52,14 +52,14 @@ ) import_helper_page = ElementImportHelper( - model = Page, + model=Page, validators=(PageLockedValidator, PageUniqueURIValidator), lang_fields=('help', 'title', 'verbose_name', 'short_title'), foreign_fields=('attribute',), - extra_fields = ( + extra_fields=( ExtraFieldHelper(field_name='is_collection'), ), - m2m_instance_fields = ('conditions', ), + m2m_instance_fields=('conditions', ), m2m_through_instance_fields=[ ThroughInstanceMapper( field_name='questionsets', source_name='page', @@ -78,39 +78,10 @@ ] ) -import_helper_question = ElementImportHelper( - model=Question, - validators=(QuestionLockedValidator, QuestionUniqueURIValidator), - lang_fields=('text', 'help', 'default_text', 'verbose_name'), - foreign_fields=('attribute', 'default_option'), - extra_fields=( - ExtraFieldHelper(field_name='is_collection'), - ExtraFieldHelper(field_name='is_optional'), - ExtraFieldHelper(field_name='default_external_id', value=''), - ExtraFieldHelper(field_name='widget_type', callback=get_widget_type_or_default), - ExtraFieldHelper(field_name='value_type', value=VALUE_TYPE_TEXT), - ExtraFieldHelper(field_name='minimum'), - ExtraFieldHelper(field_name='maximum'), - ExtraFieldHelper(field_name='step'), - ExtraFieldHelper(field_name='unit', value=''), - ExtraFieldHelper(field_name='width'), - ), - m2m_instance_fields=('conditions', 'optionsets'), - reverse_m2m_through_instance_fields=[ - ThroughInstanceMapper( - field_name='page', source_name='question', - target_name='page', through_name='question_pages' - ), - ThroughInstanceMapper( - field_name='questionset', source_name='question', - target_name='questionset', through_name='question_questionsets' - ) - ] -) import_helper_questionset = ElementImportHelper( - model = QuestionSet, + model=QuestionSet, validators=(QuestionSetLockedValidator, QuestionSetUniqueURIValidator), - lang_fields=('help', 'title', 'verbose_name'), + lang_fields=( 'title', 'help', 'verbose_name'), foreign_fields=('attribute',), extra_fields=( ExtraFieldHelper(field_name='is_collection'), @@ -138,3 +109,33 @@ ) ] ) + +import_helper_question = ElementImportHelper( + model=Question, + validators=(QuestionLockedValidator, QuestionUniqueURIValidator), + lang_fields=('text', 'help', 'default_text', 'verbose_name'), + foreign_fields=('attribute', 'default_option'), + extra_fields=( + ExtraFieldHelper(field_name='is_collection'), + ExtraFieldHelper(field_name='is_optional'), + ExtraFieldHelper(field_name='default_external_id', value=''), + ExtraFieldHelper(field_name='widget_type', callback=get_widget_type_or_default), + ExtraFieldHelper(field_name='value_type', value=VALUE_TYPE_TEXT), + ExtraFieldHelper(field_name='minimum'), + ExtraFieldHelper(field_name='maximum'), + ExtraFieldHelper(field_name='step'), + ExtraFieldHelper(field_name='unit', value=''), + ExtraFieldHelper(field_name='width'), + ), + m2m_instance_fields=('conditions', 'optionsets'), + reverse_m2m_through_instance_fields=[ + ThroughInstanceMapper( + field_name='page', source_name='question', + target_name='page', through_name='question_pages' + ), + ThroughInstanceMapper( + field_name='questionset', source_name='question', + target_name='questionset', through_name='question_questionsets' + ) + ] +)