diff --git a/specifyweb/specify/tree_utils.py b/specifyweb/specify/tree_utils.py index af442bd6787..a7663b626a0 100644 --- a/specifyweb/specify/tree_utils.py +++ b/specifyweb/specify/tree_utils.py @@ -20,13 +20,13 @@ def get_search_filters(collection: spmodels.Collection, tree: str): discipline_query |= Q(id=tree_at_discipline.id) return discipline_query -def get_treedefs(collection: spmodels.Collection, tree_name: str) -> List[Tuple[int, int]]: +def get_treedefs(collection: spmodels.Collection, tree_name: str, treedef_id=None) -> List[Tuple[int, int]]: # Get the appropriate TreeDef based on the Collection and tree_name # Mimic the old behavior of limiting the query to the first item for trees other than taxon. # Even though the queryconstruct can handle trees with multiple types. _limit = lambda query: (query if tree_name.lower() == 'taxon' else query[:1]) - search_filters = get_search_filters(collection, tree_name) + search_filters = get_search_filters(collection, tree_name) if treedef_id is None else Q(id=treedef_id) lookup_tree = lookup(tree_name) tree_table = datamodel.get_table_strict(lookup_tree) @@ -45,4 +45,3 @@ def get_treedefs(collection: spmodels.Collection, tree_name: str) -> List[Tuple assert len(result) > 0, "No definition to query on" return result - diff --git a/specifyweb/stored_queries/query_construct.py b/specifyweb/stored_queries/query_construct.py index b63fb900785..86bb737e95c 100644 --- a/specifyweb/stored_queries/query_construct.py +++ b/specifyweb/stored_queries/query_construct.py @@ -17,7 +17,13 @@ def _safe_filter(query): return query.first() raise Exception(f"Got more than one matching: {list(query)}") -class QueryConstruct(namedtuple('QueryConstruct', 'collection objectformatter query join_cache tree_rank_count internal_filters')): + +class QueryConstruct( + namedtuple( + "QueryConstruct", + "collection objectformatter query join_cache tree_rank_count internal_filters", + ) +): def __new__(cls, *args, **kwargs): kwargs['join_cache'] = 
dict() @@ -27,7 +33,7 @@ def __new__(cls, *args, **kwargs): kwargs['internal_filters'] = [] return super(QueryConstruct, cls).__new__(cls, *args, **kwargs) - def handle_tree_field(self, node, table, tree_rank, tree_field): + def handle_tree_field(self, node, table, tree_rank, tree_field, tree_def_id=None): query = self if query.collection is None: raise AssertionError( # Not sure it makes sense to query across collections f"No Collection found in Query for {table}", @@ -42,19 +48,18 @@ def handle_tree_field(self, node, table, tree_rank, tree_field): logger.debug("using join cache for %r tree ranks.", table) ancestors, treedefs = query.join_cache[(table, 'TreeRanks')] else: - - treedefs = get_treedefs(query.collection, table.name) + + treedefs = get_treedefs(query.collection, table.name, tree_def_id) # We need to take the max here. Otherwise, it is possible that the same rank # name may not occur at the same level across tree defs. max_depth = max(depth for _, depth in treedefs) - + ancestors = [node] for _ in range(max_depth-1): ancestor = orm.aliased(node) query = query.outerjoin(ancestor, ancestors[-1].ParentID == getattr(ancestor, ancestor._id)) ancestors.append(ancestor) - logger.debug("adding to join cache for %r tree ranks.", table) query = query._replace(join_cache=query.join_cache.copy()) @@ -63,10 +68,21 @@ def handle_tree_field(self, node, table, tree_rank, tree_field): item_model = getattr(spmodels, table.django_name + "treedefitem") # TODO: optimize out the ranks that appear? 
cache them - treedefs_with_ranks: List[Tuple[int, int]] = [tup for tup in [ - (treedef_id, _safe_filter(item_model.objects.filter(treedef_id=treedef_id, name=tree_rank).values_list('id', flat=True))) - for treedef_id, _ in treedefs - ] if tup[1] is not None] + treedefs_with_ranks: List[Tuple[int, int]] = [ + tup + for tup in [ + ( + treedef_id, + _safe_filter( + item_model.objects.filter( + treedef_id=treedef_id, name=tree_rank + ).values_list("id", flat=True) + ), + ) + for treedef_id, _ in treedefs + ] + if tup[1] is not None + ] assert len(treedefs_with_ranks) >= 1, "Didn't find the tree rank across any tree" @@ -76,16 +92,23 @@ def handle_tree_field(self, node, table, tree_rank, tree_field): def _predicates_for_node(_node): return [ - # TEST: consider taking the treedef_id comparison just to the first node, if it speeds things up (matching for higher is redundant..) - (sql.and_(getattr(_node, treedef_column)==treedef_id, getattr(_node, treedefitem_column)==treedefitem_id), getattr(_node, column_name)) + # TEST: consider taking the treedef_id comparison just to the first node, + # if it speeds things up (matching for higher is redundant..) 
+ ( + sql.and_( + getattr(_node, treedef_column) == treedef_id, + getattr(_node, treedefitem_column) == treedefitem_id, + ), + getattr(_node, column_name), + ) for (treedef_id, treedefitem_id) in treedefs_with_ranks ] - + cases_per_ancestor = [ _predicates_for_node(ancestor) for ancestor in ancestors - ] - + ] + column = sql.case([case for per_ancestor in cases_per_ancestor for case in per_ancestor]) defs_to_filter_on = [def_id for (def_id, _) in treedefs_with_ranks] @@ -135,7 +158,6 @@ def build_join(self, table, model, join_path): table, model = next_table, aliased return query, model, table, field - # To make things "simpler", it doesn't apply any filters, but returns a single predicate # @model is an input parameter, because cannot guess if it is aliased or not (callers are supposed to know that) def get_internal_filters(self): diff --git a/specifyweb/stored_queries/queryfieldspec.py b/specifyweb/stored_queries/queryfieldspec.py index b24ee42da7b..ab192b5b24c 100644 --- a/specifyweb/stored_queries/queryfieldspec.py +++ b/specifyweb/stored_queries/queryfieldspec.py @@ -24,6 +24,11 @@ # Pull out author or groupnumber field from taxon query fields. TAXON_FIELD_RE = re.compile(r'(.*) ((Author)|(groupNumber))$') +# MOTs tree query for a specific taxon tree. +# Schema: <table_list>.<table>.<treedef_id>,<tree_rank>,<tree_field> +# ex. 4.taxon.1,Kingdom,Author +TAXON_MOT_FIELD_RE = re.compile(r'^(\d*),([^,]*),(.*)$') + # Pull out geographyCode field from geography query fields. 
GEOGRAPHY_FIELD_RE = re.compile(r'(.*) ((geographyCode))$') @@ -61,7 +66,12 @@ def make_stringid(fs, table_list): return table_list, fs.table.name.lower(), field_name -class QueryFieldSpec(namedtuple("QueryFieldSpec", "root_table root_sql_table join_path table date_part tree_rank tree_field")): +class QueryFieldSpec( + namedtuple( + "QueryFieldSpec", + "root_table root_sql_table join_path table date_part tree_rank tree_field treedef_id", + ) +): @classmethod def from_path(cls, path_in, add_id=False): path = deque(path_in) @@ -88,8 +98,8 @@ def from_path(cls, path_in, add_id=False): table=node, date_part='Full Date' if (join_path and join_path[-1].is_temporal()) else None, tree_rank=None, - tree_field=None) - + tree_field=None, + treedef_id=None) @classmethod def from_stringid(cls, stringid, is_relation): @@ -114,15 +124,27 @@ def from_stringid(cls, stringid, is_relation): extracted_fieldname, date_part = extract_date_part(field_name) field = node.get_field(extracted_fieldname, strict=False) - tree_rank = tree_field = None + treedef_id = tree_rank = tree_field = None if field is None: tree_id_match = TREE_ID_FIELD_RE.match(extracted_fieldname) if tree_id_match: tree_rank = tree_id_match.group(1) tree_field = 'ID' else: - tree_field_match = TAXON_FIELD_RE.match(extracted_fieldname) if node is datamodel.get_table('Taxon') else GEOGRAPHY_FIELD_RE.match(extracted_fieldname) if node is datamodel.get_table('Geography') else None - if tree_field_match: + tree_mot_field_match = tree_field_match = None + if node is datamodel.get_table("Taxon"): + tree_mot_field_match = TAXON_MOT_FIELD_RE.match(extracted_fieldname) + tree_field_match = TAXON_FIELD_RE.match(extracted_fieldname) + elif node is datamodel.get_table("Geography"): + tree_field_match = GEOGRAPHY_FIELD_RE.match(extracted_fieldname) + else: + tree_field_match = None + + if tree_mot_field_match: + treedef_id = tree_mot_field_match.group(1) + tree_rank = tree_mot_field_match.group(2) + tree_field = 
tree_mot_field_match.group(3) + elif tree_field_match: tree_rank = tree_field_match.group(1) tree_field = tree_field_match.group(2) else: @@ -138,7 +160,8 @@ def from_stringid(cls, stringid, is_relation): table=node, date_part=date_part, tree_rank=tree_rank, - tree_field=tree_field) + tree_field=tree_field, + treedef_id=treedef_id) logger.debug('parsed %s (is_relation %s) to %s. extracted_fieldname = %s', stringid, is_relation, result, extracted_fieldname) @@ -195,7 +218,12 @@ def is_auditlog_obj_format_field(self, formatauditobjs): return self.get_field().name.lower() in ['oldvalue','newvalue'] def is_specify_username_end(self): - return len(self.join_path) > 2 and self.join_path[-1].name == 'name' and self.join_path[-2].is_relationship and self.join_path[-2].relatedModelName == 'SpecifyUser' + return ( + len(self.join_path) > 2 + and self.join_path[-1].name == "name" + and self.join_path[-2].is_relationship + and self.join_path[-2].relatedModelName == "SpecifyUser" + ) def apply_filter(self, query, orm_field, field, table, value=None, op_num=None, negate=False): no_filter = op_num is None or (self.tree_rank is None and self.get_field() is None) @@ -241,11 +269,19 @@ def add_spec_to_query(self, query, formatter=None, aggregator=None, cycle_detect query, orm_field = query.objectformatter.objformat(query, orm_model, formatter, cycle_detector) else: query, orm_model, table, field = self.build_join(query, self.join_path[:-1]) - orm_field = query.objectformatter.aggregate(query, self.get_field(), orm_model, aggregator or formatter, cycle_detector) + orm_field = query.objectformatter.aggregate( + query, + self.get_field(), + orm_model, + aggregator or formatter, + cycle_detector, + ) else: query, orm_model, table, field = self.build_join(query, self.join_path) if self.tree_rank is not None: - query, orm_field = query.handle_tree_field(orm_model, table, self.tree_rank, self.tree_field) + query, orm_field = query.handle_tree_field( + orm_model, table, self.tree_rank, 
self.tree_field, self.treedef_id + ) else: orm_field = getattr(orm_model, self.get_field().name)