diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 993d9f91fcb3..a1925f76d216 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -31,8 +31,11 @@ def as_oracle(self, compiler, connection, **extra_context): template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): + c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation) for expr in c.get_source_expressions(): if not hasattr(expr.field, 'geom_type'): raise ValueError('Geospatial aggregates only allowed on geometry fields.') diff --git a/django/contrib/postgres/search.py b/django/contrib/postgres/search.py index 08d92e3514bd..f3d1c8a6f1be 100644 --- a/django/contrib/postgres/search.py +++ b/django/contrib/postgres/search.py @@ -58,13 +58,20 @@ def __init__(self, *expressions, **extra): weight = Value(weight) self.weight = weight - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) if self.config: if not hasattr(self.config, 'resolve_expression'): - resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save) + resolved.config = Value(self.config).resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) else: - resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) + resolved.config = self.config.resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) return resolved def as_sql(self, compiler, connection, function=None, template=None): @@ -144,13 +151,20 @@ def __init__(self, value, output_field=None, *, config=None, invert=False, searc self.search_type = search_type super().__init__(value, output_field=output_field) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) if self.config: if not hasattr(self.config, 'resolve_expression'): - resolved.config = Value(self.config).resolve_expression(query, allow_joins, reuse, summarize, for_save) + resolved.config = Value(self.config).resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) else: - resolved.config = self.config.resolve_expression(query, allow_joins, reuse, summarize, for_save) + resolved.config = self.config.resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) return resolved def as_sql(self, compiler, connection): diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index ea88c54b0d1c..ced9c48c9083 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -42,10 +42,15 @@ def set_source_expressions(self, exprs): self.filter = self.filter and exprs.pop() return super().set_source_expressions(exprs) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): # Aggregates are not allowed in UPDATE queries, so ignore for_save c = super().resolve_expression(query, allow_joins, reuse, summarize) - c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize) + c.filter = c.filter and c.filter.resolve_expression( + query, allow_joins, reuse, summarize, reuse_with_filtered_relation=reuse_with_filtered_relation + ) if not summarize: # Call Aggregate.get_source_expressions() to avoid # returning self.filter and including that in this loop. diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index a67de51cdc80..d66eda804c09 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -220,7 +220,10 @@ def contains_over_clause(self): def contains_column_references(self): return any(expr and expr.contains_column_references for expr in self.get_source_expressions()) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): """ Provide the chance to do any preprocessing or validation before being added to the query. @@ -232,13 +235,16 @@ def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize * reuse: a set of reusable joins for multijoins * summarize: a terminal aggregate clause * for_save: whether this expression about to be used in a save or update + * reuse_with_filtered_relation: if the reuse param should be used with FilteredRelations Return: an Expression to be added to the query. """ c = self.copy() c.is_summary = summarize c.set_source_expressions([ - expr.resolve_expression(query, allow_joins, reuse, summarize) + expr.resolve_expression( + query, allow_joins, reuse, summarize, reuse_with_filtered_relation=reuse_with_filtered_relation + ) if expr else None for expr in c.get_source_expressions() ]) @@ -443,11 +449,14 @@ def as_sql(self, compiler, connection): sql = connection.ops.combine_expression(self.connector, expressions) return expression_wrapper % sql, expression_params - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): c = self.copy() c.is_summary = summarize - c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) - c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.lhs = c.lhs.resolve_expression(query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation) + c.rhs = c.rhs.resolve_expression(query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation) return c @@ -510,8 +519,10 @@ def __repr__(self): return "{}({})".format(self.__class__.__name__, self.name) def resolve_expression(self, query=None, allow_joins=True, reuse=None, - summarize=False, for_save=False, simple_col=False): - return query.resolve_ref(self.name, allow_joins, reuse, summarize, simple_col) + summarize=False, for_save=False, simple_col=False, reuse_with_filtered_relation=False): + return query.resolve_ref( + self.name, allow_joins, reuse, summarize, simple_col, reuse_with_filtered_relation + ) def asc(self, **kwargs): return OrderBy(self, **kwargs) @@ -548,7 +559,8 @@ def relabeled_clone(self, relabels): class OuterRef(F): def resolve_expression(self, query=None, allow_joins=True, reuse=None, - summarize=False, for_save=False, simple_col=False): + summarize=False, for_save=False, simple_col=False, + reuse_with_filtered_relation=False): if isinstance(self.name, self.__class__): return self.name return ResolvedOuterRef(self.name) @@ -596,11 +608,16 @@ def get_source_expressions(self): def set_source_expressions(self, exprs): self.source_expressions = exprs - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): c = self.copy() c.is_summary = summarize for pos, arg in enumerate(c.source_expressions): - c.source_expressions[pos] = arg.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.source_expressions[pos] = arg.resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) return c def as_sql(self, compiler, connection, function=None, template=None, arg_joiner=None, **extra_context): @@ -666,8 +683,11 @@ def as_sql(self, compiler, connection): return 'NULL', [] return '%s', [val] - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): + c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation) c.for_save = for_save return c @@ -801,7 +821,10 @@ def get_source_expressions(self): def set_source_expressions(self, exprs): self.source, = exprs - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): # The sub-expression `source` has already been resolved, as this is # just a reference to the name of `source`. return self @@ -886,12 +909,19 @@ def get_source_fields(self): # We're only interested in the fields of the result expressions. return [self.result._output_field_or_none] - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): c = self.copy() c.is_summary = summarize if hasattr(c.condition, 'resolve_expression'): - c.condition = c.condition.resolve_expression(query, allow_joins, reuse, summarize, False) - c.result = c.result.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.condition = c.condition.resolve_expression( + query, allow_joins, reuse, summarize, False, reuse_with_filtered_relation + ) + c.result = c.result.resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) return c def as_sql(self, compiler, connection, template=None, **extra_context): @@ -950,12 +980,19 @@ def get_source_expressions(self): def set_source_expressions(self, exprs): *self.cases, self.default = exprs - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): c = self.copy() c.is_summary = summarize for pos, case in enumerate(c.cases): - c.cases[pos] = case.resolve_expression(query, allow_joins, reuse, summarize, for_save) - c.default = c.default.resolve_expression(query, allow_joins, reuse, summarize, for_save) + c.cases[pos] = case.resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) + c.default = c.default.resolve_expression( + query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation + ) return c def copy(self): @@ -1014,7 +1051,10 @@ def copy(self): clone.queryset = clone.queryset.all() return clone - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): clone = self.copy() clone.is_summary = summarize clone.queryset.query.bump_prefix(query) @@ -1031,6 +1071,7 @@ def resolve(child): resolved = child.resolve_expression( query=query, allow_joins=allow_joins, reuse=reuse, summarize=summarize, for_save=for_save, + reuse_with_filtered_relation=reuse_with_filtered_relation ) # Add table alias to the parent query's aliases to prevent # quoting. diff --git a/django/db/models/functions/datetime.py b/django/db/models/functions/datetime.py index 177715ecfaeb..9704822102ac 100644 --- a/django/db/models/functions/datetime.py +++ b/django/db/models/functions/datetime.py @@ -60,8 +60,11 @@ def as_sql(self, compiler, connection): assert False, "Tried to Extract from an invalid type." return sql, params - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): + copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation) field = copy.lhs.output_field if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)): raise ValueError( @@ -187,8 +190,11 @@ def as_sql(self, compiler, connection): raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.') return sql, inner_params - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): - copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): + copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save, reuse_with_filtered_relation) field = copy.lhs.output_field # DateTimeField is a subclass of DateField so this works for both. assert isinstance(field, (DateField, TimeField)), ( diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index f6bc0bd030de..b8ac5ee9fab5 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -87,7 +87,10 @@ def __invert__(self): obj.negate() return obj - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, + summarize=False, for_save=False, reuse_with_filtered_relation=False + ): # We must promote any new joins to left outer joins so that when Q is # used as an expression, rows aren't filtered due to joins. clause, joins = query._add_q(self, reuse, allow_joins=allow_joins, split_subq=False) @@ -332,5 +335,14 @@ def resolve_expression(self, *args, **kwargs): def as_sql(self, compiler, connection): # Resolve the condition in Join.filtered_relation. query = compiler.query - where = query.build_filtered_relation_q(self.condition, reuse=set(self.path)) + # Add other usable aliases in the query to the reuse set. + # Check for if it can be used is in Query.join + reusable_aliases = self.path + # TODO - determine which from the alias_map have actually been applied as joins, + # TODO limit to those and make sure enough are added in the setup_joins + reusable_aliases += list(query.alias_map.keys()) + # import pdb + # pdb.set_trace() + # print("reusable_aliases", reusable_aliases) + where = query.build_filtered_relation_q(self.condition, reuse=set(reusable_aliases)) return compiler.compile(where) diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py index 9c3bb05e89f0..bbe8eae68c30 100644 --- a/django/db/models/sql/datastructures.py +++ b/django/db/models/sql/datastructures.py @@ -87,7 +87,9 @@ def as_sql(self, compiler, connection): join_conditions.append('(%s)' % extra_sql) params.extend(extra_params) if self.filtered_relation: + # print("compiling filtered relation", self.filtered_relation.relation_name) extra_sql, extra_params = compiler.compile(self.filtered_relation) + # print("compiled filtered relation", extra_sql, extra_params) if extra_sql: join_conditions.append('(%s)' % extra_sql) params.extend(extra_params) @@ -98,7 +100,11 @@ def as_sql(self, compiler, connection): "Join generated an empty ON clause. %s did not yield either " "joining columns or extra restrictions." % declared_field.__class__ ) + # TODO - on clause aliases specified here. too deep to change alias though on_clause_sql = ' AND '.join(join_conditions) + # if self.table_name in ('filtered_relation_editor',) or self.table_alias in ('T4',): + # import pdb; + # pdb.set_trace() alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias) sql = '%s %s%s ON (%s)' % (self.join_type, qn(self.table_name), alias_str, on_clause_sql) return sql, params @@ -122,6 +128,8 @@ def equals(self, other, with_filtered_relation): self.table_name == other.table_name and self.parent_alias == other.parent_alias and self.join_field == other.join_field and + # TODO - do any of these checks need changing? + # Like comparing the parent alias with a filtered relation (not with_filtered_relation or self.filtered_relation == other.filtered_relation) ) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b99f0e90efad..e61c54db026e 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -1017,7 +1017,7 @@ def resolve_expression(self, query, *args, **kwargs): def as_sql(self, compiler, connection): return self.get_compiler(connection=connection).as_sql() - def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col): + def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col, reuse_with_filtered_relation=False): if hasattr(value, 'resolve_expression'): kwargs = {'reuse': can_reuse, 'allow_joins': allow_joins} if isinstance(value, F): @@ -1034,7 +1034,10 @@ def resolve_lookup_value(self, value, can_reuse, allow_joins, simple_col): simple_col=simple_col, ) else: - sub_value.resolve_expression(self, reuse=can_reuse, allow_joins=allow_joins) + sub_value.resolve_expression( + self, reuse=can_reuse, allow_joins=allow_joins, + reuse_with_filtered_relation=reuse_with_filtered_relation + ) return value def solve_lookup_type(self, lookup): @@ -1199,7 +1202,7 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False, raise FieldError("Joined field references are not permitted in this query") pre_joins = self.alias_refcount.copy() - value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col) + value = self.resolve_lookup_value(value, can_reuse, allow_joins, simple_col, reuse_with_filtered_relation) used_joins = {k for k, v in self.alias_refcount.items() if v > pre_joins.get(k, 0)} clause = self.where_class() @@ -1217,6 +1220,26 @@ def build_filter(self, filter_expr, branch_negated=False, current_negated=False, parts, opts, alias, can_reuse=can_reuse, allow_many=allow_many, reuse_with_filtered_relation=reuse_with_filtered_relation, ) + # JoinInfo( + # final_field=, + # targets=(,), + # opts=, + # joins=['filtered_relation_author', 'book_edited_by_b'], + # path=[ + # PathInfo(from_opts=, to_opts=, + # target_fields=(,), + # join_field=, + # m2m=True, + # direct=False, + # filtered_relation=) + # ], + # transform_function=.final_transformer at 0x10845aea0> + # ) + # for p in join_info.path: + # if p.filtered_relation: + # # TODO add joins for filtered_relation.condition + # import pdb + # pdb.set_trace() # Prevent iterator from being consumed by check_related_objects() if isinstance(value, Iterator): @@ -1346,16 +1369,39 @@ def build_filtered_relation_q(self, q_object, reuse, branch_negated=False, curre return target_clause def add_filtered_relation(self, filtered_relation, alias): + # TODO add joins needed for condition, but don't apply condition to where clause filtered_relation.alias = alias + nested = False lookups = dict(get_children_from_q(filtered_relation.condition)) - for lookup in chain((filtered_relation.relation_name,), lookups): + relation_lookup_parts, relation_field_parts, _ = self.solve_lookup_type(filtered_relation.relation_name) + if relation_lookup_parts: + raise ValueError( + "FilteredRelation's relation_name doesn't support lookups " + "(got %r)." % filtered_relation.relation_name + ) + for lookup in chain(lookups): lookup_parts, field_parts, _ = self.solve_lookup_type(lookup) shift = 2 if not lookup_parts else 1 - if len(field_parts) > (shift + len(lookup_parts)): + # if the lookup field relationship is + lookup_path = field_parts[:-shift] + for relation_field_part in relation_field_parts: + if lookup_path: + if relation_field_part != lookup_path[0]: + # lookup isn't in the path, need to add path to lookup + pass + lookup_path = lookup_path[1:] + if lookup_path: raise ValueError( "FilteredRelation's condition doesn't support nested " - "relations (got %r)." % lookup + "on clauses deeper than the relation (got %r for %r)." % (lookup, filtered_relation.relation_name) ) + + # if nested: + # self.build_filtered_relation_q( + # filtered_relation.condition, + # set(list(self.alias_map.keys())) + # ) + # clause, inners = self._add_q(filtered_relation.condition, self.used_aliases) self._filtered_relations[filtered_relation.alias] = filtered_relation def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): @@ -1388,7 +1434,18 @@ def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): field = self.annotation_select[name].output_field elif name in self._filtered_relations and pos == 0: filtered_relation = self._filtered_relations[name] - field = opts.get_field(filtered_relation.relation_name) + if LOOKUP_SEP in filtered_relation.relation_name: + parts = filtered_relation.relation_name.split(LOOKUP_SEP) + filtered_relation_path, field, _, _ = self.names_to_path( + parts, opts, allow_many, fail_on_missing + ) + # TODO - add in paths for filtered_relation.condition? + # import pdb + # pdb.set_trace() + # don't want the last one, because that one gets added on below and would be duplicate + path.extend(filtered_relation_path[:-1]) + else: + field = opts.get_field(filtered_relation.relation_name) if field is not None: # Fields that contain one-to-many relations with a generic # model (like a GenericForeignKey) cannot generate reverse @@ -1430,6 +1487,9 @@ def names_to_path(self, names, opts, allow_many=True, fail_on_missing=False): opts = path_to_parent[-1].to_opts if hasattr(field, 'get_path_info'): pathinfos = field.get_path_info(filtered_relation) + # if filtered_relation: + # import pdb + # pdb.set_trace() if not allow_many: for inner_pos, p in enumerate(pathinfos): if p.m2m: @@ -1553,7 +1613,19 @@ def transform(field, alias, *, name, previous): ) joins.append(alias) if filtered_relation: + # filtered_relation_aliases = list(set(joins[:] + list(self.alias_map.keys()))) + # filtered_relation.path = filtered_relation_aliases filtered_relation.path = joins[:] + # TODO - set more filtered_relation.PATH here so we don't need to in query_utils + # self.build_filter(filtered_relation.condition) + # filtered_relation.condition.resolve_expression(self) + # expr = filtered_relation.condition.resolve_expression(self) + # q = self.build_filter(expr) + # self.build_filtered_relation_q(filtered_relation.condition) + # self.add_q(filtered_relation.condition) + # self.setup_joins() + # import pdb + # pdb.set_trace() return JoinInfo(final_field, targets, opts, joins, path, final_transformer) def trim_joins(self, targets, joins, path): @@ -1584,7 +1656,10 @@ def trim_joins(self, targets, joins, path): self.unref_alias(joins.pop()) return targets, joins[-1], joins - def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simple_col=False): + def resolve_ref( + self, name, allow_joins=True, reuse=None, + summarize=False, simple_col=False, reuse_with_filtered_relation=False + ): if not allow_joins and LOOKUP_SEP in name: raise FieldError("Joined field references are not permitted in this query") if name in self.annotations: @@ -1598,7 +1673,10 @@ def resolve_ref(self, name, allow_joins=True, reuse=None, summarize=False, simpl return self.annotations[name] else: field_list = name.split(LOOKUP_SEP) - join_info = self.setup_joins(field_list, self.get_meta(), self.get_initial_alias(), can_reuse=reuse) + join_info = self.setup_joins( + field_list, self.get_meta(), self.get_initial_alias(), + can_reuse=reuse, reuse_with_filtered_relation=reuse_with_filtered_relation + ) targets, final_alias, join_list = self.trim_joins(join_info.targets, join_info.joins, join_info.path) if not allow_joins and len(join_list) > 1: raise FieldError('Joined field references are not permitted in this query') diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index f23d4bc59817..5afc999b3ed1 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -3582,17 +3582,6 @@ operate on vegetarian pizzas. ``FilteredRelation`` doesn't support: -* Conditions that span relational fields. For example:: - - >>> Restaurant.objects.annotate( - ... pizzas_with_toppings_startswith_n=FilteredRelation( - ... 'pizzas__toppings', - ... condition=Q(pizzas__toppings__name__startswith='n'), - ... ), - ... ) - Traceback (most recent call last): - ... - ValueError: FilteredRelation's condition doesn't support nested relations (got 'pizzas__toppings__name__startswith'). * :meth:`.QuerySet.only` and :meth:`~.QuerySet.prefetch_related`. * A :class:`~django.contrib.contenttypes.fields.GenericForeignKey` inherited from a parent model. diff --git a/tests/filtered_relation/tests.py b/tests/filtered_relation/tests.py index 2596dcbdc222..1f0badb88585 100644 --- a/tests/filtered_relation/tests.py +++ b/tests/filtered_relation/tests.py @@ -1,5 +1,7 @@ from django.db import connection, transaction -from django.db.models import Case, Count, F, FilteredRelation, Q, When +from django.db.models import ( + BooleanField, Case, Count, F, FilteredRelation, Q, When, +) from django.test import TestCase from django.test.testcases import skipUnlessDBFeature @@ -220,10 +222,11 @@ def test_difference(self): self.assertSequenceEqual(qs1.difference(qs2), [self.author1]) def test_select_for_update(self): + qs = Author.objects.annotate( + book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), + ).filter(book_jane__isnull=False).select_for_update() self.assertSequenceEqual( - Author.objects.annotate( - book_jane=FilteredRelation('book', condition=Q(book__title__iexact='the book by jane a')), - ).filter(book_jane__isnull=False).select_for_update(), + qs, [self.author2] ) @@ -251,28 +254,175 @@ def test_as_subquery(self): qs = Author.objects.filter(id__in=inner_qs) self.assertSequenceEqual(qs, [self.author1]) - def test_with_foreign_key_error(self): + def test_with_nested_foreign_key(self): + qs = Author.objects.annotate( + book_editor_worked_with=FilteredRelation('book__editor', condition=Q(book__title__icontains='book')), + ).filter( + book_editor_worked_with__isnull=False + ).select_related( + 'book_editor_worked_with' + ).order_by( + 'pk', 'book_editor_worked_with__pk' + ).distinct() + + with self.assertNumQueries(1): + self.assertQuerysetEqual(qs, [ + (self.author1, self.editor_a), + (self.author2, self.editor_b), + ], lambda x: (x, x.book_editor_worked_with)) + + def test_with_nested_field(self): + qs = Author.objects.annotate( + book_editor_worked_with=FilteredRelation('book__editor', condition=Q(book__title__icontains='book')), + ).filter( + book_editor_worked_with__isnull=False + ).values( + 'name', 'book_editor_worked_with__name' + ).order_by( + 'name', 'book_editor_worked_with__name' + ).distinct() + + self.assertSequenceEqual(qs, [ + {'name': self.author1.name, 'book_editor_worked_with__name': self.editor_a.name}, + {'name': self.author2.name, 'book_editor_worked_with__name': self.editor_b.name}, + ]) + + def test_with_deep_nested_foreign_key(self): + qs = Book.objects.annotate( + author_favorite_book_editor=FilteredRelation( + 'author__favorite_books__editor', + condition=Q(author__favorite_books__title__icontains='Jane A') + ) + ).filter( + author_favorite_book_editor__isnull=False + ).select_related( + 'author_favorite_book_editor' + ).order_by( + 'pk', 'author_favorite_book_editor__pk' + ) + + with self.assertNumQueries(1): + self.assertQuerysetEqual(qs, [ + (self.book1, self.editor_b), + (self.book4, self.editor_b), + ], lambda x: (x, x.author_favorite_book_editor)) + + def test_with_foreign_key_on_condition_deeper_than_relationship_error(self): msg = ( - "FilteredRelation's condition doesn't support nested relations " - "(got 'author__favorite_books__author')." + "FilteredRelation's condition doesn't support nested " + "on clauses deeper than the relation (got 'book__editor__name__icontains' for 'book')." ) with self.assertRaisesMessage(ValueError, msg): - list(Book.objects.annotate( - alice_favorite_books=FilteredRelation( - 'author__favorite_books', - condition=Q(author__favorite_books__author=self.author1), - ) - )) + qs = Author.objects.annotate( + book_edited_by_b=FilteredRelation('book', condition=Q(book__editor__name__icontains='b')), + ).filter(book_edited_by_b__isnull=False).distinct() + list(qs) - def test_with_foreign_key_on_condition_error(self): + def test_with_foreign_key_relation_name_lookup(self): msg = ( - "FilteredRelation's condition doesn't support nested relations " - "(got 'book__editor__name__icontains')." + "FilteredRelation's relation_name doesn't support lookups " + "(got 'book__title__icontains')." ) with self.assertRaisesMessage(ValueError, msg): - list(Author.objects.annotate( - book_edited_by_b=FilteredRelation('book', condition=Q(book__editor__name__icontains='b')), - )) + qs = Author.objects.annotate( + book_edited_by_b=FilteredRelation( + 'book__title__icontains', + condition=Q(book__editor__name__icontains='b') + ), + ).filter(book_edited_by_b__isnull=False).distinct() + list(qs) + + def test_with_nested_filtered_relationship(self): + qs = Book.objects.annotate( + author_favorite_books=FilteredRelation( + 'author__favorite_books', + condition=Q(author__favorite_books__title__icontains='book') + ) + ) + qs = qs.values( + 'title', 'author__name', 'author_favorite_books__title' + ).order_by( + 'title', 'author__name' + ) + + self.assertSequenceEqual(qs, [ + { + 'title': self.book1.title, 'author__name': self.author1.name, + 'author_favorite_books__title': self.book2.title, + }, + { + 'title': self.book1.title, 'author__name': self.author1.name, + 'author_favorite_books__title': self.book3.title, + }, + { + 'title': self.book4.title, 'author__name': self.author1.name, + 'author_favorite_books__title': self.book2.title, + }, + { + 'title': self.book4.title, 'author__name': self.author1.name, + 'author_favorite_books__title': self.book3.title, + }, + { + 'title': self.book2.title, 'author__name': self.author2.name, + 'author_favorite_books__title': None, + }, + { + 'title': self.book3.title, 'author__name': self.author2.name, + 'author_favorite_books__title': None, + }, + ]) + + def test_with_multiple_relations_interacting_with_nested_on_condition(self): + qs = Author.objects.annotate( + book_editors_worked_with=FilteredRelation( + 'book__editor', condition=Q(book__title__icontains='book') + ) + ) + qs = qs.annotate( + favorite_books_by_editor_worked_with=FilteredRelation( + 'favorite_books', + condition=Q( + book__title__icontains='book', + favorite_books__editor=F("book_editors_worked_with") + ) + ) + ) + qs = qs.annotate( + author=F('name'), + book_editor_name=F("book_editors_worked_with__name"), + favorite_by_worked_with_editor=Case( + When(favorite_books_by_editor_worked_with__isnull=False, then=True), + default=False, + output_field=BooleanField() + ), + ) + qs = qs.values('author', 'book__title', 'book_editor_name', 'favorite_by_worked_with_editor').distinct() + self.assertSequenceEqual(qs, [ + { + 'book__title': self.book1.title, + 'author': self.author1.name, + 'book_editor_name': None, + 'favorite_by_worked_with_editor': False + }, + { + 'book__title': self.book4.title, + 'author': self.author1.name, + 'book_editor_name': self.editor_a.name, + 'favorite_by_worked_with_editor': False + }, + { + 'book__title': self.book2.title, + 'author': self.author2.name, + 'book_editor_name': self.editor_b.name, + 'favorite_by_worked_with_editor': False + }, + { + 'book__title': self.book3.title, + 'author': self.author2.name, + 'book_editor_name': self.editor_b.name, + 'favorite_by_worked_with_editor': False + } + ]) def test_with_empty_relation_name_error(self): with self.assertRaisesMessage(ValueError, 'relation_name cannot be empty.'):