测试gitnore

This commit is contained in:
ladeng07
2022-05-06 15:45:57 +08:00
parent 12f390949b
commit 51552904f9
2347 changed files with 120102 additions and 53549 deletions
@@ -5,34 +5,14 @@ from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import * # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import (
CASCADE,
DO_NOTHING,
PROTECT,
RESTRICT,
SET,
SET_DEFAULT,
SET_NULL,
ProtectedError,
RestrictedError,
CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL,
ProtectedError, RestrictedError,
)
from django.db.models.enums import * # NOQA
from django.db.models.enums import __all__ as enums_all
from django.db.models.expressions import (
Case,
Exists,
Expression,
ExpressionList,
ExpressionWrapper,
F,
Func,
OrderBy,
OuterRef,
RowRange,
Subquery,
Value,
ValueRange,
When,
Window,
Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
OrderBy, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window,
WindowFrame,
)
from django.db.models.fields import * # NOQA
@@ -50,66 +30,23 @@ from django.db.models.query_utils import FilteredRelation, Q
# Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model # isort:skip
from django.db.models.fields.related import ( # isort:skip
ForeignKey,
ForeignObject,
OneToOneField,
ManyToManyField,
ForeignObjectRel,
ManyToOneRel,
ManyToManyRel,
OneToOneRel,
ForeignKey, ForeignObject, OneToOneField, ManyToManyField,
ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel,
)
__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
__all__ += [
"ObjectDoesNotExist",
"signals",
"CASCADE",
"DO_NOTHING",
"PROTECT",
"RESTRICT",
"SET",
"SET_DEFAULT",
"SET_NULL",
"ProtectedError",
"RestrictedError",
"Case",
"Exists",
"Expression",
"ExpressionList",
"ExpressionWrapper",
"F",
"Func",
"OrderBy",
"OuterRef",
"RowRange",
"Subquery",
"Value",
"ValueRange",
"When",
"Window",
"WindowFrame",
"FileField",
"ImageField",
"JSONField",
"OrderWrt",
"Lookup",
"Transform",
"Manager",
"Prefetch",
"Q",
"QuerySet",
"prefetch_related_objects",
"DEFERRED",
"Model",
"FilteredRelation",
"ForeignKey",
"ForeignObject",
"OneToOneField",
"ManyToManyField",
"ForeignObjectRel",
"ManyToOneRel",
"ManyToManyRel",
"OneToOneRel",
'ObjectDoesNotExist', 'signals',
'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT',
'SET_NULL', 'ProtectedError', 'RestrictedError',
'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value',
'ValueRange', 'When',
'Window', 'WindowFrame',
'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform',
'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects',
'DEFERRED', 'Model', 'FilteredRelation',
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel',
]
@@ -4,43 +4,28 @@ Classes to represent the definitions of aggregate functions.
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import (
FixDurationInputMixin,
NumericOutputFieldMixin,
FixDurationInputMixin, NumericOutputFieldMixin,
)
__all__ = [
"Aggregate",
"Avg",
"Count",
"Max",
"Min",
"StdDev",
"Sum",
"Variance",
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
]
class Aggregate(Func):
template = "%(function)s(%(distinct)s%(expressions)s)"
template = '%(function)s(%(distinct)s%(expressions)s)'
contains_aggregate = True
name = None
filter_template = "%s FILTER (WHERE %%(filter)s)"
filter_template = '%s FILTER (WHERE %%(filter)s)'
window_compatible = True
allow_distinct = False
empty_result_set_value = None
def __init__(
self, *expressions, distinct=False, filter=None, default=None, **extra
):
def __init__(self, *expressions, distinct=False, filter=None, **extra):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if default is not None and self.empty_result_set_value is not None:
raise TypeError(f"{self.__class__.__name__} does not allow default.")
self.distinct = distinct
self.filter = filter
self.default = default
super().__init__(*expressions, **extra)
def get_source_fields(self):
@@ -57,14 +42,10 @@ class Aggregate(Func):
self.filter = self.filter and exprs.pop()
return super().set_source_expressions(exprs)
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize)
c.filter = c.filter and c.filter.resolve_expression(
query, allow_joins, reuse, summarize
)
c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize)
if not summarize:
# Call Aggregate.get_source_expressions() to avoid
# returning self.filter and including that in this loop.
@@ -72,48 +53,29 @@ class Aggregate(Func):
for index, expr in enumerate(expressions):
if expr.contains_aggregate:
before_resolved = self.get_source_expressions()[index]
name = (
before_resolved.name
if hasattr(before_resolved, "name")
else repr(before_resolved)
)
raise FieldError(
"Cannot compute %s('%s'): '%s' is an aggregate"
% (c.name, name, name)
)
if (default := c.default) is None:
return c
if hasattr(default, "resolve_expression"):
default = default.resolve_expression(query, allow_joins, reuse, summarize)
c.default = None # Reset the default argument before wrapping.
coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
coalesce.is_summary = c.is_summary
return coalesce
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
return c
@property
def default_alias(self):
expressions = self.get_source_expressions()
if len(expressions) == 1 and hasattr(expressions[0], "name"):
return "%s__%s" % (expressions[0].name, self.name.lower())
if len(expressions) == 1 and hasattr(expressions[0], 'name'):
return '%s__%s' % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")
def get_group_by_cols(self, alias=None):
return []
def as_sql(self, compiler, connection, **extra_context):
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
extra_context['distinct'] = 'DISTINCT ' if self.distinct else ''
if self.filter:
if connection.features.supports_aggregate_filter_clause:
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
template = self.filter_template % extra_context.get(
"template", self.template
)
template = self.filter_template % extra_context.get('template', self.template)
sql, params = super().as_sql(
compiler,
connection,
template=template,
filter=filter_sql,
**extra_context,
compiler, connection, template=template, filter=filter_sql,
**extra_context
)
return sql, params + filter_params
else:
@@ -122,74 +84,74 @@ class Aggregate(Func):
source_expressions = copy.get_source_expressions()
condition = When(self.filter, then=source_expressions[0])
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
return super(Aggregate, copy).as_sql(
compiler, connection, **extra_context
)
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
return super().as_sql(compiler, connection, **extra_context)
def _get_repr_options(self):
options = super()._get_repr_options()
if self.distinct:
options["distinct"] = self.distinct
options['distinct'] = self.distinct
if self.filter:
options["filter"] = self.filter
options['filter'] = self.filter
return options
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
function = "AVG"
name = "Avg"
function = 'AVG'
name = 'Avg'
allow_distinct = True
class Count(Aggregate):
function = "COUNT"
name = "Count"
function = 'COUNT'
name = 'Count'
output_field = IntegerField()
allow_distinct = True
empty_result_set_value = 0
def __init__(self, expression, filter=None, **extra):
if expression == "*":
if expression == '*':
expression = Star()
if isinstance(expression, Star) and filter is not None:
raise ValueError("Star cannot be used with filter. Please specify a field.")
raise ValueError('Star cannot be used with filter. Please specify a field.')
super().__init__(expression, filter=filter, **extra)
def convert_value(self, value, expression, connection):
return 0 if value is None else value
class Max(Aggregate):
function = "MAX"
name = "Max"
function = 'MAX'
name = 'Max'
class Min(Aggregate):
function = "MIN"
name = "Min"
function = 'MIN'
name = 'Min'
class StdDev(NumericOutputFieldMixin, Aggregate):
name = "StdDev"
name = 'StdDev'
def __init__(self, expression, sample=False, **extra):
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
super().__init__(expression, **extra)
def _get_repr_options(self):
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
class Sum(FixDurationInputMixin, Aggregate):
function = "SUM"
name = "Sum"
function = 'SUM'
name = 'Sum'
allow_distinct = True
class Variance(NumericOutputFieldMixin, Aggregate):
name = "Variance"
name = 'Variance'
def __init__(self, expression, sample=False, **extra):
self.function = "VAR_SAMP" if sample else "VAR_POP"
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
super().__init__(expression, **extra)
def _get_repr_options(self):
return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}
return {**super()._get_repr_options(), 'sample': self.function == 'VAR_SAMP'}
File diff suppressed because it is too large Load Diff
@@ -3,4 +3,4 @@ Constants used across the ORM in general.
"""
# Separator used to split filter strings apart.
LOOKUP_SEP = "__"
LOOKUP_SEP = '__'
@@ -1,34 +1,28 @@
from enum import Enum
from django.db.models.expressions import ExpressionList, F
from django.db.models.indexes import IndexExpression
from django.db.models.query_utils import Q
from django.db.models.sql.query import Query
__all__ = ["CheckConstraint", "Deferrable", "UniqueConstraint"]
__all__ = ['CheckConstraint', 'Deferrable', 'UniqueConstraint']
class BaseConstraint:
def __init__(self, name):
self.name = name
@property
def contains_expressions(self):
return False
def constraint_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
raise NotImplementedError('This method must be implemented by a subclass.')
def create_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
raise NotImplementedError('This method must be implemented by a subclass.')
def remove_sql(self, model, schema_editor):
raise NotImplementedError("This method must be implemented by a subclass.")
raise NotImplementedError('This method must be implemented by a subclass.')
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
path = path.replace("django.db.models.constraints", "django.db.models")
return (path, (), {"name": self.name})
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.constraints', 'django.db.models')
return (path, (), {'name': self.name})
def clone(self):
_, args, kwargs = self.deconstruct()
@@ -38,9 +32,10 @@ class BaseConstraint:
class CheckConstraint(BaseConstraint):
def __init__(self, *, check, name):
self.check = check
if not getattr(check, "conditional", False):
if not getattr(check, 'conditional', False):
raise TypeError(
"CheckConstraint.check must be a Q instance or boolean expression."
'CheckConstraint.check must be a Q instance or boolean '
'expression.'
)
super().__init__(name)
@@ -63,11 +58,7 @@ class CheckConstraint(BaseConstraint):
return schema_editor._delete_check_sql(model, self.name)
def __repr__(self):
return "<%s: check=%s name=%s>" % (
self.__class__.__qualname__,
self.check,
repr(self.name),
)
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
def __eq__(self, other):
if isinstance(other, CheckConstraint):
@@ -76,84 +67,62 @@ class CheckConstraint(BaseConstraint):
def deconstruct(self):
path, args, kwargs = super().deconstruct()
kwargs["check"] = self.check
kwargs['check'] = self.check
return path, args, kwargs
class Deferrable(Enum):
DEFERRED = "deferred"
IMMEDIATE = "immediate"
# A similar format was proposed for Python 3.10.
def __repr__(self):
return f"{self.__class__.__qualname__}.{self._name_}"
DEFERRED = 'deferred'
IMMEDIATE = 'immediate'
class UniqueConstraint(BaseConstraint):
def __init__(
self,
*expressions,
fields=(),
name=None,
*,
fields,
name,
condition=None,
deferrable=None,
include=None,
opclasses=(),
):
if not name:
raise ValueError("A unique constraint must be named.")
if not expressions and not fields:
raise ValueError(
"At least one field or expression is required to define a "
"unique constraint."
)
if expressions and fields:
raise ValueError(
"UniqueConstraint.fields and expressions are mutually exclusive."
)
if not fields:
raise ValueError('At least one field is required to define a unique constraint.')
if not isinstance(condition, (type(None), Q)):
raise ValueError("UniqueConstraint.condition must be a Q instance.")
raise ValueError('UniqueConstraint.condition must be a Q instance.')
if condition and deferrable:
raise ValueError("UniqueConstraint with conditions cannot be deferred.")
if include and deferrable:
raise ValueError("UniqueConstraint with include fields cannot be deferred.")
if opclasses and deferrable:
raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
if expressions and deferrable:
raise ValueError("UniqueConstraint with expressions cannot be deferred.")
if expressions and opclasses:
raise ValueError(
"UniqueConstraint.opclasses cannot be used with expressions. "
"Use django.contrib.postgres.indexes.OpClass() instead."
'UniqueConstraint with conditions cannot be deferred.'
)
if include and deferrable:
raise ValueError(
'UniqueConstraint with include fields cannot be deferred.'
)
if opclasses and deferrable:
raise ValueError(
'UniqueConstraint with opclasses cannot be deferred.'
)
if not isinstance(deferrable, (type(None), Deferrable)):
raise ValueError(
"UniqueConstraint.deferrable must be a Deferrable instance."
'UniqueConstraint.deferrable must be a Deferrable instance.'
)
if not isinstance(include, (type(None), list, tuple)):
raise ValueError("UniqueConstraint.include must be a list or tuple.")
raise ValueError('UniqueConstraint.include must be a list or tuple.')
if not isinstance(opclasses, (list, tuple)):
raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
raise ValueError('UniqueConstraint.opclasses must be a list or tuple.')
if opclasses and len(fields) != len(opclasses):
raise ValueError(
"UniqueConstraint.fields and UniqueConstraint.opclasses must "
"have the same number of elements."
'UniqueConstraint.fields and UniqueConstraint.opclasses must '
'have the same number of elements.'
)
self.fields = tuple(fields)
self.condition = condition
self.deferrable = deferrable
self.include = tuple(include) if include else ()
self.opclasses = opclasses
self.expressions = tuple(
F(expression) if isinstance(expression, str) else expression
for expression in expressions
)
super().__init__(name)
@property
def contains_expressions(self):
return bool(self.expressions)
def _get_condition_sql(self, model, schema_editor):
if self.condition is None:
return None
@@ -163,105 +132,64 @@ class UniqueConstraint(BaseConstraint):
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def _get_index_expressions(self, model, schema_editor):
if not self.expressions:
return None
index_expressions = []
for expression in self.expressions:
index_expression = IndexExpression(expression)
index_expression.set_wrapper_classes(schema_editor.connection)
index_expressions.append(index_expression)
return ExpressionList(*index_expressions).resolve_expression(
Query(model, alias_cols=False),
)
def constraint_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
fields = [model._meta.get_field(field_name).column for field_name in self.fields]
include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._unique_sql(
model,
fields,
self.name,
condition=condition,
deferrable=self.deferrable,
include=include,
model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include,
opclasses=self.opclasses,
expressions=expressions,
)
def create_sql(self, model, schema_editor):
fields = [model._meta.get_field(field_name) for field_name in self.fields]
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
fields = [model._meta.get_field(field_name).column for field_name in self.fields]
include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor)
expressions = self._get_index_expressions(model, schema_editor)
return schema_editor._create_unique_sql(
model,
fields,
self.name,
condition=condition,
deferrable=self.deferrable,
include=include,
model, fields, self.name, condition=condition,
deferrable=self.deferrable, include=include,
opclasses=self.opclasses,
expressions=expressions,
)
def remove_sql(self, model, schema_editor):
condition = self._get_condition_sql(model, schema_editor)
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
expressions = self._get_index_expressions(model, schema_editor)
include = [model._meta.get_field(field_name).column for field_name in self.include]
return schema_editor._delete_unique_sql(
model,
self.name,
condition=condition,
deferrable=self.deferrable,
include=include,
opclasses=self.opclasses,
expressions=expressions,
model, self.name, condition=condition, deferrable=self.deferrable,
include=include, opclasses=self.opclasses,
)
def __repr__(self):
return "<%s:%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
"" if not self.fields else " fields=%s" % repr(self.fields),
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
" name=%s" % repr(self.name),
"" if self.condition is None else " condition=%s" % self.condition,
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
"" if not self.include else " include=%s" % repr(self.include),
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
return '<%s: fields=%r name=%r%s%s%s%s>' % (
self.__class__.__name__, self.fields, self.name,
'' if self.condition is None else ' condition=%s' % self.condition,
'' if self.deferrable is None else ' deferrable=%s' % self.deferrable,
'' if not self.include else ' include=%s' % repr(self.include),
'' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
)
def __eq__(self, other):
if isinstance(other, UniqueConstraint):
return (
self.name == other.name
and self.fields == other.fields
and self.condition == other.condition
and self.deferrable == other.deferrable
and self.include == other.include
and self.opclasses == other.opclasses
and self.expressions == other.expressions
self.name == other.name and
self.fields == other.fields and
self.condition == other.condition and
self.deferrable == other.deferrable and
self.include == other.include and
self.opclasses == other.opclasses
)
return super().__eq__(other)
def deconstruct(self):
path, args, kwargs = super().deconstruct()
if self.fields:
kwargs["fields"] = self.fields
kwargs['fields'] = self.fields
if self.condition:
kwargs["condition"] = self.condition
kwargs['condition'] = self.condition
if self.deferrable:
kwargs["deferrable"] = self.deferrable
kwargs['deferrable'] = self.deferrable
if self.include:
kwargs["include"] = self.include
kwargs['include'] = self.include
if self.opclasses:
kwargs["opclasses"] = self.opclasses
return path, self.expressions, kwargs
kwargs['opclasses'] = self.opclasses
return path, args, kwargs
@@ -1,5 +1,6 @@
import operator
from collections import Counter, defaultdict
from functools import partial
from functools import partial, reduce
from itertools import chain
from operator import attrgetter
@@ -21,11 +22,8 @@ class RestrictedError(IntegrityError):
def CASCADE(collector, field, sub_objs, using):
collector.collect(
sub_objs,
source=field.remote_field.model,
source_attr=field.name,
nullable=field.null,
fail_on_restricted=False,
sub_objs, source=field.remote_field.model, source_attr=field.name,
nullable=field.null, fail_on_restricted=False,
)
if field.null and not connections[using].features.can_defer_constraint_checks:
collector.add_field_update(field, None, sub_objs)
@@ -34,13 +32,10 @@ def CASCADE(collector, field, sub_objs, using):
def PROTECT(collector, field, sub_objs, using):
raise ProtectedError(
"Cannot delete some instances of model '%s' because they are "
"referenced through a protected foreign key: '%s.%s'"
% (
field.remote_field.model.__name__,
sub_objs[0].__class__.__name__,
field.name,
"referenced through a protected foreign key: '%s.%s'" % (
field.remote_field.model.__name__, sub_objs[0].__class__.__name__, field.name
),
sub_objs,
sub_objs
)
@@ -51,16 +46,12 @@ def RESTRICT(collector, field, sub_objs, using):
def SET(value):
if callable(value):
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value(), sub_objs)
else:
def set_on_delete(collector, field, sub_objs, using):
collector.add_field_update(field, value, sub_objs)
set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
set_on_delete.deconstruct = lambda: ('django.db.models.SET', (value,), {})
return set_on_delete
@@ -80,8 +71,7 @@ def get_candidate_relations_to_delete(opts):
# The candidate relations are the ones that come from N-1 and 1-1 relations.
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
return (
f
for f in opts.get_fields(include_hidden=True)
f for f in opts.get_fields(include_hidden=True)
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
)
@@ -133,9 +123,7 @@ class Collector:
def add_dependency(self, model, dependency, reverse_dependency=False):
if reverse_dependency:
model, dependency = dependency, model
self.dependencies[model._meta.concrete_model].add(
dependency._meta.concrete_model
)
self.dependencies[model._meta.concrete_model].add(dependency._meta.concrete_model)
self.data.setdefault(dependency, self.data.default_factory())
def add_field_update(self, field, value, objs):
@@ -162,21 +150,17 @@ class Collector:
def clear_restricted_objects_from_queryset(self, model, qs):
if model in self.restricted_objects:
objs = set(
qs.filter(
pk__in=[
obj.pk
for objs in self.restricted_objects[model].values()
for obj in objs
]
)
)
objs = set(qs.filter(pk__in=[
obj.pk
for objs in self.restricted_objects[model].values() for obj in objs
]))
self.clear_restricted_objects_from_set(model, objs)
def _has_signal_listeners(self, model):
return signals.pre_delete.has_listeners(
model
) or signals.post_delete.has_listeners(model)
return (
signals.pre_delete.has_listeners(model) or
signals.post_delete.has_listeners(model)
)
def can_fast_delete(self, objs, from_field=None):
"""
@@ -191,9 +175,9 @@ class Collector:
"""
if from_field and from_field.remote_field.on_delete is not CASCADE:
return False
if hasattr(objs, "_meta"):
if hasattr(objs, '_meta'):
model = objs._meta.model
elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
elif hasattr(objs, 'model') and hasattr(objs, '_raw_delete'):
model = objs.model
else:
return False
@@ -203,22 +187,14 @@ class Collector:
# parent when parent delete is cascading to child.
opts = model._meta
return (
all(
link == from_field
for link in opts.concrete_model._meta.parents.values()
)
and
all(link == from_field for link in opts.concrete_model._meta.parents.values()) and
# Foreign keys pointing to this model.
all(
related.field.remote_field.on_delete is DO_NOTHING
for related in get_candidate_relations_to_delete(opts)
)
and (
) and (
# Something like generic foreign key.
not any(
hasattr(field, "bulk_related_objects")
for field in opts.private_fields
)
not any(hasattr(field, 'bulk_related_objects') for field in opts.private_fields)
)
)
@@ -228,27 +204,16 @@ class Collector:
"""
field_names = [field.name for field in fields]
conn_batch_size = max(
connections[self.using].ops.bulk_batch_size(field_names, objs), 1
)
connections[self.using].ops.bulk_batch_size(field_names, objs), 1)
if len(objs) > conn_batch_size:
return [
objs[i : i + conn_batch_size]
for i in range(0, len(objs), conn_batch_size)
]
return [objs[i:i + conn_batch_size]
for i in range(0, len(objs), conn_batch_size)]
else:
return [objs]
def collect(
self,
objs,
source=None,
nullable=False,
collect_related=True,
source_attr=None,
reverse_dependency=False,
keep_parents=False,
fail_on_restricted=True,
):
def collect(self, objs, source=None, nullable=False, collect_related=True,
source_attr=None, reverse_dependency=False, keep_parents=False,
fail_on_restricted=True):
"""
Add 'objs' to the collection of objects to be deleted as well as all
parent instances. 'objs' must be a homogeneous iterable collection of
@@ -275,9 +240,8 @@ class Collector:
if self.can_fast_delete(objs):
self.fast_deletes.append(objs)
return
new_objs = self.add(
objs, source, nullable, reverse_dependency=reverse_dependency
)
new_objs = self.add(objs, source, nullable,
reverse_dependency=reverse_dependency)
if not new_objs:
return
@@ -290,14 +254,11 @@ class Collector:
for ptr in concrete_model._meta.parents.values():
if ptr:
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
self.collect(
parent_objs,
source=model,
source_attr=ptr.remote_field.related_name,
collect_related=False,
reverse_dependency=True,
fail_on_restricted=False,
)
self.collect(parent_objs, source=model,
source_attr=ptr.remote_field.related_name,
collect_related=False,
reverse_dependency=True,
fail_on_restricted=False)
if not collect_related:
return
@@ -325,18 +286,11 @@ class Collector:
# relationships are select_related as interactions between both
# features are hard to get right. This should only happen in
# the rare cases where .related_objects is overridden anyway.
if not (
sub_objs.query.select_related
or self._has_signal_listeners(related_model)
):
referenced_fields = set(
chain.from_iterable(
(rf.attname for rf in rel.field.foreign_related_fields)
for rel in get_candidate_relations_to_delete(
related_model._meta
)
)
)
if not (sub_objs.query.select_related or self._has_signal_listeners(related_model)):
referenced_fields = set(chain.from_iterable(
(rf.attname for rf in rel.field.foreign_related_fields)
for rel in get_candidate_relations_to_delete(related_model._meta)
))
sub_objs = sub_objs.only(*tuple(referenced_fields))
if sub_objs:
try:
@@ -346,11 +300,10 @@ class Collector:
protected_objects[key] += error.protected_objects
if protected_objects:
raise ProtectedError(
"Cannot delete some instances of model %r because they are "
"referenced through protected foreign keys: %s."
% (
'Cannot delete some instances of model %r because they are '
'referenced through protected foreign keys: %s.' % (
model.__name__,
", ".join(protected_objects),
', '.join(protected_objects),
),
set(chain.from_iterable(protected_objects.values())),
)
@@ -360,12 +313,10 @@ class Collector:
sub_objs = self.related_objects(related_model, related_fields, batch)
self.fast_deletes.append(sub_objs)
for field in model._meta.private_fields:
if hasattr(field, "bulk_related_objects"):
if hasattr(field, 'bulk_related_objects'):
# It's something like generic foreign key.
sub_objs = field.bulk_related_objects(new_objs, self.using)
self.collect(
sub_objs, source=model, nullable=True, fail_on_restricted=False
)
self.collect(sub_objs, source=model, nullable=True, fail_on_restricted=False)
if fail_on_restricted:
# Raise an error if collected restricted objects (RESTRICT) aren't
@@ -383,12 +334,11 @@ class Collector:
restricted_objects[key] += objs
if restricted_objects:
raise RestrictedError(
"Cannot delete some instances of model %r because "
"they are referenced through restricted foreign keys: "
"%s."
% (
'Cannot delete some instances of model %r because '
'they are referenced through restricted foreign keys: '
'%s.' % (
model.__name__,
", ".join(restricted_objects),
', '.join(restricted_objects),
),
set(chain.from_iterable(restricted_objects.values())),
)
@@ -397,10 +347,10 @@ class Collector:
"""
Get a QuerySet of the related model to objs via related fields.
"""
predicate = query_utils.Q(
*((f"{related_field.name}__in", objs) for related_field in related_fields),
_connector=query_utils.Q.OR,
)
predicate = reduce(operator.or_, (
query_utils.Q(**{'%s__in' % related_field.name: objs})
for related_field in related_fields
))
return related_model._base_manager.using(self.using).filter(predicate)
def instances_with_model(self):
@@ -443,9 +393,7 @@ class Collector:
instance = list(instances)[0]
if self.can_fast_delete(instance):
with transaction.mark_for_rollback_on_error(self.using):
count = sql.DeleteQuery(model).delete_batch(
[instance.pk], self.using
)
count = sql.DeleteQuery(model).delete_batch([instance.pk], self.using)
setattr(instance, model._meta.pk.attname, None)
return count, {model._meta.label: count}
@@ -467,9 +415,8 @@ class Collector:
for model, instances_for_fieldvalues in self.field_updates.items():
for (field, value), instances in instances_for_fieldvalues.items():
query = sql.UpdateQuery(model)
query.update_batch(
[obj.pk for obj in instances], {field.name: value}, self.using
)
query.update_batch([obj.pk for obj in instances],
{field.name: value}, self.using)
# reverse instance collections
for instances in self.data.values():
@@ -1,9 +1,8 @@
import enum
from types import DynamicClassAttribute
from django.utils.functional import Promise
__all__ = ["Choices", "IntegerChoices", "TextChoices"]
__all__ = ['Choices', 'IntegerChoices', 'TextChoices']
class ChoicesMeta(enum.EnumMeta):
@@ -14,21 +13,25 @@ class ChoicesMeta(enum.EnumMeta):
for key in classdict._member_names:
value = classdict[key]
if (
isinstance(value, (list, tuple))
and len(value) > 1
and isinstance(value[-1], (Promise, str))
isinstance(value, (list, tuple)) and
len(value) > 1 and
isinstance(value[-1], (Promise, str))
):
*value, label = value
value = tuple(value)
else:
label = key.replace("_", " ").title()
label = key.replace('_', ' ').title()
labels.append(label)
# Use dict.__setitem__() to suppress defenses against double
# assignment in enum's classdict.
dict.__setitem__(classdict, key, value)
cls = super().__new__(metacls, classname, bases, classdict, **kwds)
for member, label in zip(cls.__members__.values(), labels):
member._label_ = label
cls._value2label_map_ = dict(zip(cls._value2member_map_, labels))
# Add a label property to instances of enum which uses the enum member
# that is passed in as "self" as the value to use when looking up the
# label in the choices.
cls.label = property(lambda self: cls._value2label_map_.get(self.value))
cls.do_not_call_in_templates = True
return enum.unique(cls)
def __contains__(cls, member):
@@ -39,12 +42,12 @@ class ChoicesMeta(enum.EnumMeta):
@property
def names(cls):
empty = ["__empty__"] if hasattr(cls, "__empty__") else []
empty = ['__empty__'] if hasattr(cls, '__empty__') else []
return empty + [member.name for member in cls]
@property
def choices(cls):
empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else []
return empty + [(member.value, member.label) for member in cls]
@property
@@ -59,14 +62,6 @@ class ChoicesMeta(enum.EnumMeta):
class Choices(enum.Enum, metaclass=ChoicesMeta):
"""Class for creating enumerated choices."""
@DynamicClassAttribute
def label(self):
return self._label_
@property
def do_not_call_in_templates(self):
return True
def __str__(self):
"""
Use value when cast to str, so that Choices set as model instance
@@ -74,14 +69,9 @@ class Choices(enum.Enum, metaclass=ChoicesMeta):
"""
return str(self.value)
# A similar format was proposed for Python 3.10.
def __repr__(self):
return f"{self.__class__.__qualname__}.{self._name_}"
class IntegerChoices(int, Choices):
"""Class for creating enumerated integer choices."""
pass
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -24,7 +24,7 @@ class FieldFile(File):
def __eq__(self, other):
# Older code may be expecting FileField values to be simple strings.
# By overriding the == operator, it can remain backwards compatibility.
if hasattr(other, "name"):
if hasattr(other, 'name'):
return self.name == other.name
return self.name == other
@@ -37,14 +37,12 @@ class FieldFile(File):
def _require_file(self):
if not self:
raise ValueError(
"The '%s' attribute has no file associated with it." % self.field.name
)
raise ValueError("The '%s' attribute has no file associated with it." % self.field.name)
def _get_file(self):
self._require_file()
if getattr(self, "_file", None) is None:
self._file = self.storage.open(self.name, "rb")
if getattr(self, '_file', None) is None:
self._file = self.storage.open(self.name, 'rb')
return self._file
def _set_file(self, file):
@@ -72,14 +70,13 @@ class FieldFile(File):
return self.file.size
return self.storage.size(self.name)
def open(self, mode="rb"):
def open(self, mode='rb'):
self._require_file()
if getattr(self, "_file", None) is None:
if getattr(self, '_file', None) is None:
self.file = self.storage.open(self.name, mode)
else:
self.file.open(mode)
return self
# open() doesn't alter the file's contents, but it does reset the pointer
open.alters_data = True
@@ -96,7 +93,6 @@ class FieldFile(File):
# Save the object because it has changed, unless save is False
if save:
self.instance.save()
save.alters_data = True
def delete(self, save=True):
@@ -104,7 +100,7 @@ class FieldFile(File):
return
# Only close the file if it's already open, which we know by the
# presence of self._file
if hasattr(self, "_file"):
if hasattr(self, '_file'):
self.close()
del self.file
@@ -116,16 +112,15 @@ class FieldFile(File):
if save:
self.instance.save()
delete.alters_data = True
@property
def closed(self):
file = getattr(self, "_file", None)
file = getattr(self, '_file', None)
return file is None or file.closed
def close(self):
file = getattr(self, "_file", None)
file = getattr(self, '_file', None)
if file is not None:
file.close()
@@ -134,12 +129,12 @@ class FieldFile(File):
# the file's name. Everything else will be restored later, by
# FileDescriptor below.
return {
"name": self.name,
"closed": False,
"_committed": True,
"_file": None,
"instance": self.instance,
"field": self.field,
'name': self.name,
'closed': False,
'_committed': True,
'_file': None,
'instance': self.instance,
'field': self.field,
}
def __setstate__(self, state):
@@ -161,7 +156,6 @@ class FileDescriptor(DeferredAttribute):
>>> with open('/path/to/hello.world') as f:
... instance.file = File(f)
"""
def __get__(self, instance, cls=None):
if instance is None:
return self
@@ -204,7 +198,7 @@ class FileDescriptor(DeferredAttribute):
# Finally, because of the (some would say boneheaded) way pickle works,
# the underlying FieldFile might not actually itself have an associated
# file. So we need to reset the details of the FieldFile in those cases.
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
elif isinstance(file, FieldFile) and not hasattr(file, 'field'):
file.instance = instance
file.field = self.field
file.storage = self.field.storage
@@ -231,10 +225,8 @@ class FileField(Field):
description = _("File")
def __init__(
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
):
self._primary_key_set_explicitly = "primary_key" in kwargs
def __init__(self, verbose_name=None, name=None, upload_to='', storage=None, **kwargs):
self._primary_key_set_explicitly = 'primary_key' in kwargs
self.storage = storage or default_storage
if callable(self.storage):
@@ -244,15 +236,11 @@ class FileField(Field):
if not isinstance(self.storage, Storage):
raise TypeError(
"%s.storage must be a subclass/instance of %s.%s"
% (
self.__class__.__qualname__,
Storage.__module__,
Storage.__qualname__,
)
% (self.__class__.__qualname__, Storage.__module__, Storage.__qualname__)
)
self.upload_to = upload_to
kwargs.setdefault("max_length", 100)
kwargs.setdefault('max_length', 100)
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
@@ -266,24 +254,23 @@ class FileField(Field):
if self._primary_key_set_explicitly:
return [
checks.Error(
"'primary_key' is not a valid argument for a %s."
% self.__class__.__name__,
"'primary_key' is not a valid argument for a %s." % self.__class__.__name__,
obj=self,
id="fields.E201",
id='fields.E201',
)
]
else:
return []
def _check_upload_to(self):
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
if isinstance(self.upload_to, str) and self.upload_to.startswith('/'):
return [
checks.Error(
"%s's 'upload_to' argument must be a relative path, not an "
"absolute path." % self.__class__.__name__,
obj=self,
id="fields.E202",
hint="Remove the leading slash.",
id='fields.E202',
hint='Remove the leading slash.',
)
]
else:
@@ -293,9 +280,9 @@ class FileField(Field):
name, path, args, kwargs = super().deconstruct()
if kwargs.get("max_length") == 100:
del kwargs["max_length"]
kwargs["upload_to"] = self.upload_to
kwargs['upload_to'] = self.upload_to
if self.storage is not default_storage:
kwargs["storage"] = getattr(self, "_storage_callable", self.storage)
kwargs['storage'] = getattr(self, '_storage_callable', self.storage)
return name, path, args, kwargs
def get_internal_type(self):
@@ -303,8 +290,7 @@ class FileField(Field):
def get_prep_value(self, value):
value = super().get_prep_value(value)
# Need to convert File objects provided via a form to string for
# database insertion.
# Need to convert File objects provided via a form to string for database insertion
if value is None:
return None
return str(value)
@@ -343,16 +329,14 @@ class FileField(Field):
if data is not None:
# This value will be converted to str and stored in the
# database, so leaving False as-is is not acceptable.
setattr(instance, self.name, data or "")
setattr(instance, self.name, data or '')
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.FileField,
"max_length": self.max_length,
**kwargs,
}
)
return super().formfield(**{
'form_class': forms.FileField,
'max_length': self.max_length,
**kwargs,
})
class ImageFileDescriptor(FileDescriptor):
@@ -360,7 +344,6 @@ class ImageFileDescriptor(FileDescriptor):
Just like the FileDescriptor, but for ImageFields. The only difference is
assigning the width/height to the width_field/height_field, if appropriate.
"""
def __set__(self, instance, value):
previous_file = instance.__dict__.get(self.field.attname)
super().__set__(instance, value)
@@ -381,7 +364,7 @@ class ImageFileDescriptor(FileDescriptor):
class ImageFieldFile(ImageFile, FieldFile):
def delete(self, save=True):
# Clear the image dimensions cache
if hasattr(self, "_dimensions_cache"):
if hasattr(self, '_dimensions_cache'):
del self._dimensions_cache
super().delete(save)
@@ -391,14 +374,7 @@ class ImageField(FileField):
descriptor_class = ImageFileDescriptor
description = _("Image")
def __init__(
self,
verbose_name=None,
name=None,
width_field=None,
height_field=None,
**kwargs,
):
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
self.width_field, self.height_field = width_field, height_field
super().__init__(verbose_name, name, **kwargs)
@@ -414,13 +390,11 @@ class ImageField(FileField):
except ImportError:
return [
checks.Error(
"Cannot use ImageField because Pillow is not installed.",
hint=(
"Get Pillow at https://pypi.org/project/Pillow/ "
'or run command "python -m pip install Pillow".'
),
'Cannot use ImageField because Pillow is not installed.',
hint=('Get Pillow at https://pypi.org/project/Pillow/ '
'or run command "python -m pip install Pillow".'),
obj=self,
id="fields.E210",
id='fields.E210',
)
]
else:
@@ -429,9 +403,9 @@ class ImageField(FileField):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.width_field:
kwargs["width_field"] = self.width_field
kwargs['width_field'] = self.width_field
if self.height_field:
kwargs["height_field"] = self.height_field
kwargs['height_field'] = self.height_field
return name, path, args, kwargs
def contribute_to_class(self, cls, name, **kwargs):
@@ -471,9 +445,9 @@ class ImageField(FileField):
if not file and not force:
return
dimension_fields_filled = not (
(self.width_field and not getattr(instance, self.width_field))
or (self.height_field and not getattr(instance, self.height_field))
dimension_fields_filled = not(
(self.width_field and not getattr(instance, self.width_field)) or
(self.height_field and not getattr(instance, self.height_field))
)
# When both dimension fields have values, we are most likely loading
# data from the database or updating an image field that already had
@@ -501,9 +475,7 @@ class ImageField(FileField):
setattr(instance, self.height_field, height)
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.ImageField,
**kwargs,
}
)
return super().formfield(**{
'form_class': forms.ImageField,
**kwargs,
})
@@ -10,36 +10,32 @@ from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import CheckFieldDefaultMixin
__all__ = ["JSONField"]
__all__ = ['JSONField']
class JSONField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
description = _("A JSON object")
description = _('A JSON object')
default_error_messages = {
"invalid": _("Value must be valid JSON."),
'invalid': _('Value must be valid JSON.'),
}
_default_hint = ("dict", "{}")
_default_hint = ('dict', '{}')
def __init__(
self,
verbose_name=None,
name=None,
encoder=None,
decoder=None,
self, verbose_name=None, name=None, encoder=None, decoder=None,
**kwargs,
):
if encoder and not callable(encoder):
raise ValueError("The encoder parameter must be a callable object.")
raise ValueError('The encoder parameter must be a callable object.')
if decoder and not callable(decoder):
raise ValueError("The decoder parameter must be a callable object.")
raise ValueError('The decoder parameter must be a callable object.')
self.encoder = encoder
self.decoder = decoder
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
errors = super().check(**kwargs)
databases = kwargs.get("databases") or []
databases = kwargs.get('databases') or []
errors.extend(self._check_supported(databases))
return errors
@@ -50,19 +46,20 @@ class JSONField(CheckFieldDefaultMixin, Field):
continue
connection = connections[db]
if (
self.model._meta.required_db_vendor
and self.model._meta.required_db_vendor != connection.vendor
self.model._meta.required_db_vendor and
self.model._meta.required_db_vendor != connection.vendor
):
continue
if not (
"supports_json_field" in self.model._meta.required_db_features
or connection.features.supports_json_field
'supports_json_field' in self.model._meta.required_db_features or
connection.features.supports_json_field
):
errors.append(
checks.Error(
"%s does not support JSONFields." % connection.display_name,
'%s does not support JSONFields.'
% connection.display_name,
obj=self.model,
id="fields.E180",
id='fields.E180',
)
)
return errors
@@ -70,9 +67,9 @@ class JSONField(CheckFieldDefaultMixin, Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.encoder is not None:
kwargs["encoder"] = self.encoder
kwargs['encoder'] = self.encoder
if self.decoder is not None:
kwargs["decoder"] = self.decoder
kwargs['decoder'] = self.decoder
return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
@@ -88,7 +85,7 @@ class JSONField(CheckFieldDefaultMixin, Field):
return value
def get_internal_type(self):
return "JSONField"
return 'JSONField'
def get_prep_value(self, value):
if value is None:
@@ -107,66 +104,64 @@ class JSONField(CheckFieldDefaultMixin, Field):
json.dumps(value, cls=self.encoder)
except TypeError:
raise exceptions.ValidationError(
self.error_messages["invalid"],
code="invalid",
params={"value": value},
self.error_messages['invalid'],
code='invalid',
params={'value': value},
)
def value_to_string(self, obj):
return self.value_from_object(obj)
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.JSONField,
"encoder": self.encoder,
"decoder": self.decoder,
**kwargs,
}
)
return super().formfield(**{
'form_class': forms.JSONField,
'encoder': self.encoder,
'decoder': self.decoder,
**kwargs,
})
def compile_json_path(key_transforms, include_root=True):
path = ["$"] if include_root else []
path = ['$'] if include_root else []
for key_transform in key_transforms:
try:
num = int(key_transform)
except ValueError: # non-integer
path.append(".")
path.append('.')
path.append(json.dumps(key_transform))
else:
path.append("[%s]" % num)
return "".join(path)
path.append('[%s]' % num)
return ''.join(path)
class DataContains(PostgresOperatorLookup):
lookup_name = "contains"
postgres_operator = "@>"
lookup_name = 'contains'
postgres_operator = '@>'
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
"contains lookup is not supported on this database backend."
'contains lookup is not supported on this database backend.'
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params
class ContainedBy(PostgresOperatorLookup):
lookup_name = "contained_by"
postgres_operator = "<@"
lookup_name = 'contained_by'
postgres_operator = '<@'
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
"contained_by lookup is not supported on this database backend."
'contained_by lookup is not supported on this database backend.'
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(rhs_params) + tuple(lhs_params)
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params
class HasKeyLookup(PostgresOperatorLookup):
@@ -175,13 +170,11 @@ class HasKeyLookup(PostgresOperatorLookup):
def as_sql(self, compiler, connection, template=None):
# Process JSON path from the left-hand side.
if isinstance(self.lhs, KeyTransform):
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
compiler, connection
)
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
lhs_json_path = compile_json_path(lhs_key_transforms)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
lhs_json_path = "$"
lhs_json_path = '$'
sql = template % lhs
# Process JSON path from the right-hand side.
rhs = self.rhs
@@ -193,27 +186,20 @@ class HasKeyLookup(PostgresOperatorLookup):
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
else:
rhs_key_transforms = [key]
rhs_params.append(
"%s%s"
% (
lhs_json_path,
compile_json_path(rhs_key_transforms, include_root=False),
)
)
rhs_params.append('%s%s' % (
lhs_json_path,
compile_json_path(rhs_key_transforms, include_root=False),
))
# Add condition for each key.
if self.logical_operator:
sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
return sql, tuple(lhs_params) + tuple(rhs_params)
def as_mysql(self, compiler, connection):
return self.as_sql(
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
)
return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")
def as_oracle(self, compiler, connection):
sql, params = self.as_sql(
compiler, connection, template="JSON_EXISTS(%s, '%%s')"
)
sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
# Add paths directly into SQL because path expressions cannot be passed
# as bind variables on Oracle.
return sql % tuple(params), []
@@ -227,83 +213,64 @@ class HasKeyLookup(PostgresOperatorLookup):
return super().as_postgresql(compiler, connection)
def as_sqlite(self, compiler, connection):
return self.as_sql(
compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
)
return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
class HasKey(HasKeyLookup):
lookup_name = "has_key"
postgres_operator = "?"
lookup_name = 'has_key'
postgres_operator = '?'
prepare_rhs = False
class HasKeys(HasKeyLookup):
lookup_name = "has_keys"
postgres_operator = "?&"
logical_operator = " AND "
lookup_name = 'has_keys'
postgres_operator = '?&'
logical_operator = ' AND '
def get_prep_lookup(self):
return [str(item) for item in self.rhs]
class HasAnyKeys(HasKeys):
lookup_name = "has_any_keys"
postgres_operator = "?|"
logical_operator = " OR "
class CaseInsensitiveMixin:
"""
Mixin to allow case-insensitive comparison of JSON values on MySQL.
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == "mysql":
return "LOWER(%s)" % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == "mysql":
return "LOWER(%s)" % rhs, rhs_params
return rhs, rhs_params
lookup_name = 'has_any_keys'
postgres_operator = '?|'
logical_operator = ' OR '
class JSONExact(lookups.Exact):
can_use_none_as_rhs = True
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'sqlite':
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs == '%s' and rhs_params == [None]:
# Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
lhs = "JSON_TYPE(%s, '$')" % lhs
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
if rhs == "%s" and rhs_params == [None]:
rhs_params = ["null"]
if connection.vendor == "mysql":
if rhs == '%s' and rhs_params == [None]:
rhs_params = ['null']
if connection.vendor == 'mysql':
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
rhs = rhs % tuple(func)
return rhs, rhs_params
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
pass
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
class KeyTransform(Transform):
postgres_operator = "->"
postgres_nested_operator = "#>"
postgres_operator = '->'
postgres_nested_operator = '#>'
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -316,50 +283,44 @@ class KeyTransform(Transform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
lhs, params = compiler.compile(previous)
if connection.vendor == "oracle":
if connection.vendor == 'oracle':
# Escape string-formatting.
key_transforms = [key.replace("%", "%%") for key in key_transforms]
key_transforms = [key.replace('%', '%%') for key in key_transforms]
return lhs, params, key_transforms
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return (
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
% ((lhs, json_path) * 2)
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
((lhs, json_path) * 2)
), tuple(params) * 2
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
if len(key_transforms) > 1:
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
return sql, tuple(params) + (key_transforms,)
try:
lookup = int(self.key_name)
except ValueError:
lookup = self.key_name
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
datatype_values = ",".join(
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
)
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
class KeyTextTransform(KeyTransform):
postgres_operator = "->>"
postgres_nested_operator = "#>>"
postgres_operator = '->>'
postgres_nested_operator = '#>>'
class KeyTransformTextLookupMixin:
@@ -369,21 +330,39 @@ class KeyTransformTextLookupMixin:
key values to text and performing the lookup on the resulting
representation.
"""
def __init__(self, key_transform, *args, **kwargs):
if not isinstance(key_transform, KeyTransform):
raise TypeError(
"Transform should be an instance of KeyTransform in order to "
"use this lookup."
'Transform should be an instance of KeyTransform in order to '
'use this lookup.'
)
key_text_transform = KeyTextTransform(
key_transform.key_name,
*key_transform.source_expressions,
key_transform.key_name, *key_transform.source_expressions,
**key_transform.extra,
)
super().__init__(key_text_transform, *args, **kwargs)
class CaseInsensitiveMixin:
"""
Mixin to allow case-insensitive comparison of JSON values on MySQL.
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'mysql':
return 'LOWER(%s)' % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == 'mysql':
return 'LOWER(%s)' % rhs, rhs_params
return rhs, rhs_params
class KeyTransformIsNull(lookups.IsNull):
# key__isnull=False is the same as has_key='key'
def as_oracle(self, compiler, connection):
@@ -395,12 +374,12 @@ class KeyTransformIsNull(lookups.IsNull):
return sql, params
# Column doesn't have a key or IS NULL.
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
def as_sqlite(self, compiler, connection):
template = "JSON_TYPE(%s, %%s) IS NULL"
template = 'JSON_TYPE(%s, %%s) IS NULL'
if not self.rhs:
template = "JSON_TYPE(%s, %%s) IS NOT NULL"
template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
compiler,
connection,
@@ -411,81 +390,75 @@ class KeyTransformIsNull(lookups.IsNull):
class KeyTransformIn(lookups.In):
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
compiler,
connection,
sql,
param,
compiler, connection, sql, param,
)
if (
not hasattr(param, "as_sql")
and not connection.features.has_native_json_field
not hasattr(param, 'as_sql') and
not connection.features.has_native_json_field
):
if connection.vendor == "oracle":
if connection.vendor == 'oracle':
value = json.loads(param)
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
if isinstance(value, (list, dict)):
sql = sql % "JSON_QUERY"
sql = sql % 'JSON_QUERY'
else:
sql = sql % "JSON_VALUE"
elif connection.vendor == "mysql" or (
connection.vendor == "sqlite"
and params[0] not in connection.ops.jsonfield_datatype_values
):
sql = sql % 'JSON_VALUE'
elif connection.vendor in {'sqlite', 'mysql'}:
sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
sql = "JSON_UNQUOTE(%s)" % sql
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
sql = 'JSON_UNQUOTE(%s)' % sql
return sql, params
class KeyTransformExact(JSONExact):
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'sqlite':
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs == '%s' and rhs_params == ['null']:
lhs, *_ = self.lhs.preprocess_lhs(compiler, connection)
lhs = 'JSON_TYPE(%s, %%s)' % lhs
return lhs, lhs_params
def process_rhs(self, compiler, connection):
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == "oracle":
if connection.vendor == 'oracle':
func = []
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
for value in rhs_params:
value = json.loads(value)
if isinstance(value, (list, dict)):
func.append(sql % "JSON_QUERY")
func.append(sql % 'JSON_QUERY')
else:
func.append(sql % "JSON_VALUE")
func.append(sql % 'JSON_VALUE')
rhs = rhs % tuple(func)
elif connection.vendor == "sqlite":
func = []
for value in rhs_params:
if value in connection.ops.jsonfield_datatype_values:
func.append("%s")
else:
func.append("JSON_EXTRACT(%s, '$')")
elif connection.vendor == 'sqlite':
func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params]
rhs = rhs % tuple(func)
return rhs, rhs_params
def as_oracle(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs_params == ["null"]:
if rhs_params == ['null']:
# Field has key and it's NULL.
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
return (
"%s AND %s" % (has_key_sql, is_null_sql),
'%s AND %s' % (has_key_sql, is_null_sql),
tuple(has_key_params) + tuple(is_null_params),
)
return super().as_sql(compiler, connection)
class KeyTransformIExact(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
pass
class KeyTransformIContains(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
pass
@@ -493,9 +466,7 @@ class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
pass
class KeyTransformIStartsWith(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
pass
@@ -503,9 +474,7 @@ class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
pass
class KeyTransformIEndsWith(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
pass
@@ -513,9 +482,7 @@ class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
pass
class KeyTransformIRegex(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
pass
@@ -562,6 +529,7 @@ KeyTransform.register_lookup(KeyTransformGte)
class KeyTransformFactory:
def __init__(self, key_name):
self.key_name = key_name
@@ -29,25 +29,22 @@ class FieldCacheMixin:
class CheckFieldDefaultMixin:
_default_hint = ("<valid default>", "<invalid default>")
_default_hint = ('<valid default>', '<invalid default>')
def _check_default(self):
if (
self.has_default()
and self.default is not None
and not callable(self.default)
):
if self.has_default() and self.default is not None and not callable(self.default):
return [
checks.Warning(
"%s default should be a callable instead of an instance "
"so that it's not shared between all field instances."
% (self.__class__.__name__,),
"so that it's not shared between all field instances." % (
self.__class__.__name__,
),
hint=(
"Use a callable instead, e.g., use `%s` instead of "
"`%s`." % self._default_hint
'Use a callable instead, e.g., use `%s` instead of '
'`%s`.' % self._default_hint
),
obj=self,
id="fields.E010",
id='fields.E010',
)
]
else:
@@ -13,6 +13,6 @@ class OrderWrt(fields.IntegerField):
"""
def __init__(self, *args, **kwargs):
kwargs["name"] = "_order"
kwargs["editable"] = False
kwargs['name'] = '_order'
kwargs['editable'] = False
super().__init__(*args, **kwargs)
File diff suppressed because it is too large Load Diff
@@ -74,9 +74,7 @@ from django.utils.functional import cached_property
class ForeignKeyDeferredAttribute(DeferredAttribute):
def __set__(self, instance, value):
if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(
instance
):
if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(instance):
self.field.delete_cached_value(instance)
instance.__dict__[self.field.attname] = value
@@ -103,16 +101,14 @@ class ForwardManyToOneDescriptor:
# related model might not be resolved yet; `self.field.model` might
# still be a string model reference.
return type(
"RelatedObjectDoesNotExist",
(self.field.remote_field.model.DoesNotExist, AttributeError),
{
"__module__": self.field.model.__module__,
"__qualname__": "%s.%s.RelatedObjectDoesNotExist"
% (
'RelatedObjectDoesNotExist',
(self.field.remote_field.model.DoesNotExist, AttributeError), {
'__module__': self.field.model.__module__,
'__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
self.field.model.__qualname__,
self.field.name,
),
},
}
)
def is_cached(self, instance):
@@ -139,12 +135,9 @@ class ForwardManyToOneDescriptor:
# The check for len(...) == 1 is a special case that allows the query
# to be join-less and smaller. Refs #21760.
if remote_field.is_hidden() or len(self.field.foreign_related_fields) == 1:
query = {
"%s__in"
% related_field.name: {instance_attr(inst)[0] for inst in instances}
}
query = {'%s__in' % related_field.name: {instance_attr(inst)[0] for inst in instances}}
else:
query = {"%s__in" % self.field.related_query_name(): instances}
query = {'%s__in' % self.field.related_query_name(): instances}
queryset = queryset.filter(**query)
# Since we're going to assign directly in the cache,
@@ -153,14 +146,7 @@ class ForwardManyToOneDescriptor:
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
remote_field.set_cached_value(rel_obj, instance)
return (
queryset,
rel_obj_attr,
instance_attr,
True,
self.field.get_cache_name(),
False,
)
return queryset, rel_obj_attr, instance_attr, True, self.field.get_cache_name(), False
def get_object(self, instance):
qs = self.get_queryset(instance=instance)
@@ -187,11 +173,7 @@ class ForwardManyToOneDescriptor:
rel_obj = self.field.get_cached_value(instance)
except KeyError:
has_value = None not in self.field.get_local_related_value(instance)
ancestor_link = (
instance._meta.get_ancestor_link(self.field.model)
if has_value
else None
)
ancestor_link = instance._meta.get_ancestor_link(self.field.model) if has_value else None
if ancestor_link and ancestor_link.is_cached(instance):
# An ancestor link will exist if this field is defined on a
# multi-table inheritance parent of the instance's class.
@@ -229,12 +211,9 @@ class ForwardManyToOneDescriptor:
- ``value`` is the ``parent`` instance on the right of the equal sign
"""
# An object must be an instance of the related class.
if value is not None and not isinstance(
value, self.field.remote_field.model._meta.concrete_model
):
if value is not None and not isinstance(value, self.field.remote_field.model._meta.concrete_model):
raise ValueError(
'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
% (
'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
value,
instance._meta.object_name,
self.field.name,
@@ -243,18 +222,11 @@ class ForwardManyToOneDescriptor:
)
elif value is not None:
if instance._state.db is None:
instance._state.db = router.db_for_write(
instance.__class__, instance=value
)
instance._state.db = router.db_for_write(instance.__class__, instance=value)
if value._state.db is None:
value._state.db = router.db_for_write(
value.__class__, instance=instance
)
value._state.db = router.db_for_write(value.__class__, instance=instance)
if not router.allow_relation(value, instance):
raise ValueError(
'Cannot assign "%r": the current database router prevents this '
"relation." % value
)
raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
remote_field = self.field.remote_field
# If we're setting the value of a OneToOneField to None, we need to clear
@@ -342,15 +314,12 @@ class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
opts = instance._meta
# Inherited primary key fields from this object's base classes.
inherited_pk_fields = [
field
for field in opts.concrete_fields
field for field in opts.concrete_fields
if field.primary_key and field.remote_field
]
for field in inherited_pk_fields:
rel_model_pk_name = field.remote_field.model._meta.pk.attname
raw_value = (
getattr(value, rel_model_pk_name) if value is not None else None
)
raw_value = getattr(value, rel_model_pk_name) if value is not None else None
setattr(instance, rel_model_pk_name, raw_value)
@@ -377,15 +346,13 @@ class ReverseOneToOneDescriptor:
# The exception isn't created at initialization time for the sake of
# consistency with `ForwardManyToOneDescriptor`.
return type(
"RelatedObjectDoesNotExist",
(self.related.related_model.DoesNotExist, AttributeError),
{
"__module__": self.related.model.__module__,
"__qualname__": "%s.%s.RelatedObjectDoesNotExist"
% (
'RelatedObjectDoesNotExist',
(self.related.related_model.DoesNotExist, AttributeError), {
'__module__': self.related.model.__module__,
'__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
self.related.model.__qualname__,
self.related.name,
),
)
},
)
@@ -403,7 +370,7 @@ class ReverseOneToOneDescriptor:
rel_obj_attr = self.related.field.get_local_related_value
instance_attr = self.related.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
query = {"%s__in" % self.related.field.name: instances}
query = {'%s__in' % self.related.field.name: instances}
queryset = queryset.filter(**query)
# Since we're going to assign directly in the cache,
@@ -411,14 +378,7 @@ class ReverseOneToOneDescriptor:
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
self.related.field.set_cached_value(rel_obj, instance)
return (
queryset,
rel_obj_attr,
instance_attr,
True,
self.related.get_cache_name(),
False,
)
return queryset, rel_obj_attr, instance_attr, True, self.related.get_cache_name(), False
def __get__(self, instance, cls=None):
"""
@@ -459,8 +419,10 @@ class ReverseOneToOneDescriptor:
if rel_obj is None:
raise self.RelatedObjectDoesNotExist(
"%s has no %s."
% (instance.__class__.__name__, self.related.get_accessor_name())
"%s has no %s." % (
instance.__class__.__name__,
self.related.get_accessor_name()
)
)
else:
return rel_obj
@@ -496,8 +458,7 @@ class ReverseOneToOneDescriptor:
elif not isinstance(value, self.related.related_model):
# An object must be an instance of the related class.
raise ValueError(
'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
% (
'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
value,
instance._meta.object_name,
self.related.get_accessor_name(),
@@ -506,25 +467,14 @@ class ReverseOneToOneDescriptor:
)
else:
if instance._state.db is None:
instance._state.db = router.db_for_write(
instance.__class__, instance=value
)
instance._state.db = router.db_for_write(instance.__class__, instance=value)
if value._state.db is None:
value._state.db = router.db_for_write(
value.__class__, instance=instance
)
value._state.db = router.db_for_write(value.__class__, instance=instance)
if not router.allow_relation(value, instance):
raise ValueError(
'Cannot assign "%r": the current database router prevents this '
"relation." % value
)
raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
related_pk = tuple(
getattr(instance, field.attname)
for field in self.related.field.foreign_related_fields
)
# Set the value of the related field to the value of the related
# object's related field.
related_pk = tuple(getattr(instance, field.attname) for field in self.related.field.foreign_related_fields)
# Set the value of the related field to the value of the related object's related field
for index, field in enumerate(self.related.field.local_related_fields):
setattr(value, field.attname, related_pk[index])
@@ -587,13 +537,13 @@ class ReverseManyToOneDescriptor:
def _get_set_deprecation_msg_params(self):
return (
"reverse side of a related set",
'reverse side of a related set',
self.rel.get_accessor_name(),
)
def __set__(self, instance, value):
raise TypeError(
"Direct assignment to the %s is prohibited. Use %s.set() instead."
'Direct assignment to the %s is prohibited. Use %s.set() instead.'
% self._get_set_deprecation_msg_params(),
)
@@ -620,7 +570,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
manager = getattr(self.model, manager)
manager_class = create_reverse_many_to_one_manager(manager.__class__, rel)
return manager_class(self.instance)
do_not_call_in_templates = True
def _apply_rel_filters(self, queryset):
@@ -628,9 +577,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
Filter the queryset for the instance this manager is bound to.
"""
db = self._db or router.db_for_read(self.model, instance=self.instance)
empty_strings_as_null = connections[
db
].features.interprets_empty_strings_as_nulls
empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
queryset._add_hints(instance=self.instance)
if self._db:
queryset = queryset.using(self._db)
@@ -638,7 +585,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
queryset = queryset.filter(**self.core_filters)
for field in self.field.foreign_related_fields:
val = getattr(self.instance, field.attname)
if val is None or (val == "" and empty_strings_as_null):
if val is None or (val == '' and empty_strings_as_null):
return queryset.none()
if self.field.many_to_one:
# Guard against field-like objects such as GenericRelation
@@ -650,34 +597,24 @@ def create_reverse_many_to_one_manager(superclass, rel):
except FieldError:
# The relationship has multiple target fields. Use a tuple
# for related object id.
rel_obj_id = tuple(
[
getattr(self.instance, target_field.attname)
for target_field in self.field.get_path_info()[
-1
].target_fields
]
)
rel_obj_id = tuple([
getattr(self.instance, target_field.attname)
for target_field in self.field.get_path_info()[-1].target_fields
])
else:
rel_obj_id = getattr(self.instance, target_field.attname)
queryset._known_related_objects = {
self.field: {rel_obj_id: self.instance}
}
queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
return queryset
def _remove_prefetched_objects(self):
try:
self.instance._prefetched_objects_cache.pop(
self.field.remote_field.get_cache_name()
)
self.instance._prefetched_objects_cache.pop(self.field.remote_field.get_cache_name())
except (AttributeError, KeyError):
pass # nothing to clear from cache
def get_queryset(self):
try:
return self.instance._prefetched_objects_cache[
self.field.remote_field.get_cache_name()
]
return self.instance._prefetched_objects_cache[self.field.remote_field.get_cache_name()]
except (AttributeError, KeyError):
queryset = super().get_queryset()
return self._apply_rel_filters(queryset)
@@ -692,7 +629,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
rel_obj_attr = self.field.get_local_related_value
instance_attr = self.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
query = {"%s__in" % self.field.name: instances}
query = {'%s__in' % self.field.name: instances}
queryset = queryset.filter(**query)
# Since we just bypassed this class' get_queryset(), we must manage
@@ -709,13 +646,9 @@ def create_reverse_many_to_one_manager(superclass, rel):
def check_and_update_obj(obj):
if not isinstance(obj, self.model):
raise TypeError(
"'%s' instance expected, got %r"
% (
self.model._meta.object_name,
obj,
)
)
raise TypeError("'%s' instance expected, got %r" % (
self.model._meta.object_name, obj,
))
setattr(obj, self.field.name, self.instance)
if bulk:
@@ -728,44 +661,36 @@ def create_reverse_many_to_one_manager(superclass, rel):
"the object first." % obj
)
pks.append(obj.pk)
self.model._base_manager.using(db).filter(pk__in=pks).update(
**{
self.field.name: self.instance,
}
)
self.model._base_manager.using(db).filter(pk__in=pks).update(**{
self.field.name: self.instance,
})
else:
with transaction.atomic(using=db, savepoint=False):
for obj in objs:
check_and_update_obj(obj)
obj.save()
add.alters_data = True
def create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).create(**kwargs)
create.alters_data = True
def get_or_create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
get_or_create.alters_data = True
def update_or_create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs)
update_or_create.alters_data = True
# remove() and clear() are only provided if the ForeignKey can have a
# value of null.
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel.field.null:
def remove(self, *objs, bulk=True):
if not objs:
return
@@ -773,13 +698,9 @@ def create_reverse_many_to_one_manager(superclass, rel):
old_ids = set()
for obj in objs:
if not isinstance(obj, self.model):
raise TypeError(
"'%s' instance expected, got %r"
% (
self.model._meta.object_name,
obj,
)
)
raise TypeError("'%s' instance expected, got %r" % (
self.model._meta.object_name, obj,
))
# Is obj actually part of this descriptor set?
if self.field.get_local_related_value(obj) == val:
old_ids.add(obj.pk)
@@ -788,12 +709,10 @@ def create_reverse_many_to_one_manager(superclass, rel):
"%r is not related to %r." % (obj, self.instance)
)
self._clear(self.filter(pk__in=old_ids), bulk)
remove.alters_data = True
def clear(self, *, bulk=True):
self._clear(self, bulk)
clear.alters_data = True
def _clear(self, queryset, bulk):
@@ -808,7 +727,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
for obj in queryset:
setattr(obj, self.field.name, None)
obj.save(update_fields=[self.field.name])
_clear.alters_data = True
def set(self, objs, *, bulk=True, clear=False):
@@ -835,7 +753,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
self.add(*new_objs, bulk=bulk)
else:
self.add(*objs, bulk=bulk)
set.alters_data = True
return RelatedManager
@@ -882,8 +799,7 @@ class ManyToManyDescriptor(ReverseManyToOneDescriptor):
def _get_set_deprecation_msg_params(self):
return (
"%s side of a many-to-many set"
% ("reverse" if self.reverse else "forward"),
'%s side of a many-to-many set' % ('reverse' if self.reverse else 'forward'),
self.rel.get_accessor_name() if self.reverse else self.field.name,
)
@@ -926,51 +842,42 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.core_filters = {}
self.pk_field_names = {}
for lh_field, rh_field in self.source_field.related_fields:
core_filter_key = "%s__%s" % (self.query_field_name, rh_field.name)
core_filter_key = '%s__%s' % (self.query_field_name, rh_field.name)
self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
self.pk_field_names[lh_field.name] = rh_field.name
self.related_val = self.source_field.get_foreign_related_value(instance)
if None in self.related_val:
raise ValueError(
'"%r" needs to have a value for field "%s" before '
"this many-to-many relationship can be used."
% (instance, self.pk_field_names[self.source_field_name])
)
raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
(instance, self.pk_field_names[self.source_field_name]))
# Even if this relation is not to pk, we require still pk value.
# The wish is that the instance has been already saved to DB,
# although having a pk value isn't a guarantee of that.
if instance.pk is None:
raise ValueError(
"%r instance needs to have a primary key value before "
"a many-to-many relationship can be used."
% instance.__class__.__name__
)
raise ValueError("%r instance needs to have a primary key value before "
"a many-to-many relationship can be used." %
instance.__class__.__name__)
def __call__(self, *, manager):
manager = getattr(self.model, manager)
manager_class = create_forward_many_to_many_manager(
manager.__class__, rel, reverse
)
manager_class = create_forward_many_to_many_manager(manager.__class__, rel, reverse)
return manager_class(instance=self.instance)
do_not_call_in_templates = True
def _build_remove_filters(self, removed_vals):
filters = Q((self.source_field_name, self.related_val))
filters = Q(**{self.source_field_name: self.related_val})
# No need to add a subquery condition if removed_vals is a QuerySet without
# filters.
removed_vals_filters = (
not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()
)
removed_vals_filters = (not isinstance(removed_vals, QuerySet) or
removed_vals._has_filters())
if removed_vals_filters:
filters &= Q((f"{self.target_field_name}__in", removed_vals))
filters &= Q(**{'%s__in' % self.target_field_name: removed_vals})
if self.symmetrical:
symmetrical_filters = Q((self.target_field_name, self.related_val))
symmetrical_filters = Q(**{self.target_field_name: self.related_val})
if removed_vals_filters:
symmetrical_filters &= Q(
(f"{self.source_field_name}__in", removed_vals)
)
**{'%s__in' % self.source_field_name: removed_vals})
filters |= symmetrical_filters
return filters
@@ -1004,7 +911,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)
query = {"%s__in" % self.query_field_name: instances}
query = {'%s__in' % self.query_field_name: instances}
queryset = queryset._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model
@@ -1018,18 +925,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
join_table = fk.model._meta.db_table
connection = connections[queryset.db]
qn = connection.ops.quote_name
queryset = queryset.extra(
select={
"_prefetch_related_val_%s"
% f.attname: "%s.%s"
% (qn(join_table), qn(f.column))
for f in fk.local_related_fields
}
)
queryset = queryset.extra(select={
'_prefetch_related_val_%s' % f.attname:
'%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
return (
queryset,
lambda result: tuple(
getattr(result, "_prefetch_related_val_%s" % f.attname)
getattr(result, '_prefetch_related_val_%s' % f.attname)
for f in fk.local_related_fields
),
lambda inst: tuple(
@@ -1046,9 +948,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
db = router.db_for_write(self.through, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
self._add_items(
self.source_field_name,
self.target_field_name,
*objs,
self.source_field_name, self.target_field_name, *objs,
through_defaults=through_defaults,
)
# If this is a symmetrical m2m relation to self, add the mirror
@@ -1060,41 +960,30 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
*objs,
through_defaults=through_defaults,
)
add.alters_data = True
def remove(self, *objs):
self._remove_prefetched_objects()
self._remove_items(self.source_field_name, self.target_field_name, *objs)
remove.alters_data = True
def clear(self):
db = router.db_for_write(self.through, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
signals.m2m_changed.send(
sender=self.through,
action="pre_clear",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=None,
using=db,
sender=self.through, action="pre_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db,
)
self._remove_prefetched_objects()
filters = self._build_remove_filters(super().get_queryset().using(db))
self.through._default_manager.using(db).filter(filters).delete()
signals.m2m_changed.send(
sender=self.through,
action="post_clear",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=None,
using=db,
sender=self.through, action="post_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db,
)
clear.alters_data = True
def set(self, objs, *, clear=False, through_defaults=None):
@@ -1108,11 +997,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.clear()
self.add(*objs, through_defaults=through_defaults)
else:
old_ids = set(
self.using(db).values_list(
self.target_field.target_field.attname, flat=True
)
)
old_ids = set(self.using(db).values_list(self.target_field.target_field.attname, flat=True))
new_objs = []
for obj in objs:
@@ -1128,7 +1013,6 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.remove(*old_ids)
self.add(*new_objs, through_defaults=through_defaults)
set.alters_data = True
def create(self, *, through_defaults=None, **kwargs):
@@ -1136,33 +1020,26 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
self.add(new_obj, through_defaults=through_defaults)
return new_obj
create.alters_data = True
def get_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
**kwargs
)
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj, through_defaults=through_defaults)
return obj, created
get_or_create.alters_data = True
def update_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(
ManyRelatedManager, self.db_manager(db)
).update_or_create(**kwargs)
obj, created = super(ManyRelatedManager, self.db_manager(db)).update_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj, through_defaults=through_defaults)
return obj, created
update_or_create.alters_data = True
def _get_target_ids(self, target_field_name, objs):
@@ -1170,7 +1047,6 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
Return the set of ids of `objs` that the target field references.
"""
from django.db.models import Model
target_ids = set()
target_field = self.through._meta.get_field(target_field_name)
for obj in objs:
@@ -1178,42 +1054,36 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
if not router.allow_relation(obj, self.instance):
raise ValueError(
'Cannot add "%r": instance is on database "%s", '
'value is on database "%s"'
% (obj, self.instance._state.db, obj._state.db)
'value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db)
)
target_id = target_field.get_foreign_related_value(obj)[0]
if target_id is None:
raise ValueError(
'Cannot add "%r": the value for field "%s" is None'
% (obj, target_field_name)
'Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name)
)
target_ids.add(target_id)
elif isinstance(obj, Model):
raise TypeError(
"'%s' instance expected, got %r"
% (self.model._meta.object_name, obj)
"'%s' instance expected, got %r" %
(self.model._meta.object_name, obj)
)
else:
target_ids.add(target_field.get_prep_value(obj))
return target_ids
def _get_missing_target_ids(
self, source_field_name, target_field_name, db, target_ids
):
def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids):
"""
Return the subset of ids of `objs` that aren't already assigned to
this relationship.
"""
vals = (
self.through._default_manager.using(db)
.values_list(target_field_name, flat=True)
.filter(
**{
source_field_name: self.related_val[0],
"%s__in" % target_field_name: target_ids,
}
)
)
vals = self.through._default_manager.using(db).values_list(
target_field_name, flat=True
).filter(**{
source_field_name: self.related_val[0],
'%s__in' % target_field_name: target_ids,
})
return target_ids.difference(vals)
def _get_add_plan(self, db, source_field_name):
@@ -1231,53 +1101,39 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# user-defined intermediary models as they could have other fields
# causing conflicts which must be surfaced.
can_ignore_conflicts = (
connections[db].features.supports_ignore_conflicts
and self.through._meta.auto_created is not False
connections[db].features.supports_ignore_conflicts and
self.through._meta.auto_created is not False
)
# Don't send the signal when inserting duplicate data row
# for symmetrical reverse entries.
must_send_signals = (
self.reverse or source_field_name == self.source_field_name
) and (signals.m2m_changed.has_listeners(self.through))
must_send_signals = (self.reverse or source_field_name == self.source_field_name) and (
signals.m2m_changed.has_listeners(self.through)
)
# Fast addition through bulk insertion can only be performed
# if no m2m_changed listeners are connected for self.through
# as they require the added set of ids to be provided via
# pk_set.
return (
can_ignore_conflicts,
must_send_signals,
(can_ignore_conflicts and not must_send_signals),
)
return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals)
def _add_items(
self, source_field_name, target_field_name, *objs, through_defaults=None
):
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
# source_field_name: the PK fieldname in join table for the source object
# target_field_name: the PK fieldname in join table for the target object
# *objs - objects to add. Either object instances, or primary keys
# of object instances.
# *objs - objects to add. Either object instances, or primary keys of object instances.
if not objs:
return
through_defaults = dict(resolve_callables(through_defaults or {}))
target_ids = self._get_target_ids(target_field_name, objs)
db = router.db_for_write(self.through, instance=self.instance)
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(
db, source_field_name
)
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name)
if can_fast_add:
self.through._default_manager.using(db).bulk_create(
[
self.through(
**{
"%s_id" % source_field_name: self.related_val[0],
"%s_id" % target_field_name: target_id,
}
)
for target_id in target_ids
],
ignore_conflicts=True,
)
self.through._default_manager.using(db).bulk_create([
self.through(**{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in target_ids
], ignore_conflicts=True)
return
missing_target_ids = self._get_missing_target_ids(
@@ -1286,38 +1142,24 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
with transaction.atomic(using=db, savepoint=False):
if must_send_signals:
signals.m2m_changed.send(
sender=self.through,
action="pre_add",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=missing_target_ids,
using=db,
sender=self.through, action='pre_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db,
)
# Add the ones that aren't there already.
self.through._default_manager.using(db).bulk_create(
[
self.through(
**through_defaults,
**{
"%s_id" % source_field_name: self.related_val[0],
"%s_id" % target_field_name: target_id,
},
)
for target_id in missing_target_ids
],
ignore_conflicts=can_ignore_conflicts,
)
self.through._default_manager.using(db).bulk_create([
self.through(**through_defaults, **{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in missing_target_ids
], ignore_conflicts=can_ignore_conflicts)
if must_send_signals:
signals.m2m_changed.send(
sender=self.through,
action="post_add",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=missing_target_ids,
using=db,
sender=self.through, action='post_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db,
)
def _remove_items(self, source_field_name, target_field_name, *objs):
@@ -1341,32 +1183,23 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
with transaction.atomic(using=db, savepoint=False):
# Send a signal to the other end if need be.
signals.m2m_changed.send(
sender=self.through,
action="pre_remove",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=old_ids,
using=db,
sender=self.through, action="pre_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db,
)
target_model_qs = super().get_queryset()
if target_model_qs._has_filters():
old_vals = target_model_qs.using(db).filter(
**{"%s__in" % self.target_field.target_field.attname: old_ids}
)
old_vals = target_model_qs.using(db).filter(**{
'%s__in' % self.target_field.target_field.attname: old_ids})
else:
old_vals = old_ids
filters = self._build_remove_filters(old_vals)
self.through._default_manager.using(db).filter(filters).delete()
signals.m2m_changed.send(
sender=self.through,
action="post_remove",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=old_ids,
using=db,
sender=self.through, action="post_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db,
)
return ManyRelatedManager
@@ -1,10 +1,5 @@
from django.db.models.lookups import (
Exact,
GreaterThan,
GreaterThanOrEqual,
In,
IsNull,
LessThan,
Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
LessThanOrEqual,
)
@@ -13,40 +8,29 @@ class MultiColSource:
contains_aggregate = False
def __init__(self, alias, targets, sources, field):
self.targets, self.sources, self.field, self.alias = (
targets,
sources,
field,
alias,
)
self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
self.output_field = self.field
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
return "{}({}, {})".format(
self.__class__.__name__, self.alias, self.field)
def relabeled_clone(self, relabels):
return self.__class__(
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
)
return self.__class__(relabels.get(self.alias, self.alias),
self.targets, self.sources, self.field)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def resolve_expression(self, *args, **kwargs):
return self
def get_normalized_value(value, lhs):
from django.db.models import Model
if isinstance(value, Model):
value_list = []
sources = lhs.output_field.get_path_info()[-1].target_fields
for source in sources:
while not isinstance(value, source.model) and source.remote_field:
source = source.remote_field.model._meta.get_field(
source.remote_field.field_name
)
source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
try:
value_list.append(getattr(value, source.attname))
except AttributeError:
@@ -68,26 +52,20 @@ class RelatedIn(In):
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if hasattr(self.lhs.output_field, "get_path_info"):
if hasattr(self.lhs.output_field, 'get_path_info'):
# Run the target field's get_prep_value. We can safely assume there is
# only one as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[
-1
]
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
# For multicolumn lookups we need to build a multicolumn where clause.
# This clause is either a SubqueryConstraint (for values that need
# to be compiled to SQL) or an OR-combined list of
# (col1 = val1 AND col2 = val2 AND ...) clauses.
# This clause is either a SubqueryConstraint (for values that need to be compiled to
# SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
from django.db.models.sql.where import (
AND,
OR,
SubqueryConstraint,
WhereNode,
AND, OR, SubqueryConstraint, WhereNode,
)
root_constraint = WhereNode(connector=OR)
@@ -95,35 +73,24 @@ class RelatedIn(In):
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
for value in values:
value_constraint = WhereNode()
for source, target, val in zip(
self.lhs.sources, self.lhs.targets, value
):
lookup_class = target.get_lookup("exact")
lookup = lookup_class(
target.get_col(self.lhs.alias, source), val
)
for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
lookup_class = target.get_lookup('exact')
lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
root_constraint.add(
SubqueryConstraint(
self.lhs.alias,
[target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources],
self.rhs,
),
AND,
)
self.lhs.alias, [target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources], self.rhs),
AND)
return root_constraint.as_sql(compiler, connection)
else:
if not getattr(self.rhs, "has_select_fields", True) and not getattr(
self.lhs.field.target_field, "primary_key", False
):
if (not getattr(self.rhs, 'has_select_fields', True) and
not getattr(self.lhs.field.target_field, 'primary_key', False)):
self.rhs.clear_select_clause()
if (
getattr(self.lhs.output_field, "primary_key", False)
and self.lhs.output_field.model == self.rhs.model
):
if (getattr(self.lhs.output_field, 'primary_key', False) and
self.lhs.output_field.model == self.rhs.model):
# A case like Restaurant.objects.filter(place__in=restaurant_qs),
# where place is a OneToOneField and the primary key of
# Restaurant.
@@ -136,21 +103,17 @@ class RelatedIn(In):
class RelatedLookupMixin:
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and not hasattr(
self.rhs, "resolve_expression"
):
if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'):
# If we get here, we are dealing with single-column relations.
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
# We need to run the related field's get_prep_value(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if self.prepare_rhs and hasattr(self.lhs.output_field, "get_path_info"):
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
# Get the target field. We can safely assume there is only one
# as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[
-1
]
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
self.rhs = target_field.get_prep_value(self.rhs)
return super().get_prep_lookup()
@@ -160,15 +123,11 @@ class RelatedLookupMixin:
assert self.rhs_is_direct_value()
self.rhs = get_normalized_value(self.rhs, self.lhs)
from django.db.models.sql.where import AND, WhereNode
root_constraint = WhereNode()
for target, source, val in zip(
self.lhs.targets, self.lhs.sources, self.rhs
):
for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
lookup_class = target.get_lookup(self.lookup_name)
root_constraint.add(
lookup_class(target.get_col(self.lhs.alias, source), val), AND
)
lookup_class(target.get_col(self.lhs.alias, source), val), AND)
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
@@ -36,16 +36,8 @@ class ForeignObjectRel(FieldCacheMixin):
null = True
empty_strings_allowed = False
def __init__(
self,
field,
to,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
def __init__(self, field, to, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
self.field = field
self.model = to
self.related_name = related_name
@@ -81,18 +73,14 @@ class ForeignObjectRel(FieldCacheMixin):
"""
target_fields = self.get_path_info()[-1].target_fields
if len(target_fields) > 1:
raise exceptions.FieldError(
"Can't use target_field for multicolumn relations."
)
raise exceptions.FieldError("Can't use target_field for multicolumn relations.")
return target_fields[0]
@cached_property
def related_model(self):
if not self.field.model:
raise AttributeError(
"This property can't be accessed before self.field.contribute_to_class "
"has been called."
)
"This property can't be accessed before self.field.contribute_to_class has been called.")
return self.field.model
@cached_property
@@ -122,7 +110,7 @@ class ForeignObjectRel(FieldCacheMixin):
return self.field.db_type
def __repr__(self):
return "<%s: %s.%s>" % (
return '<%s: %s.%s>' % (
type(self).__name__,
self.related_model._meta.app_label,
self.related_model._meta.model_name,
@@ -151,11 +139,8 @@ class ForeignObjectRel(FieldCacheMixin):
return hash(self.identity)
def get_choices(
self,
include_blank=True,
blank_choice=BLANK_CHOICE_DASH,
limit_choices_to=None,
ordering=(),
self, include_blank=True, blank_choice=BLANK_CHOICE_DASH,
limit_choices_to=None, ordering=(),
):
"""
Return choices with a default blank choices included, for use
@@ -168,17 +153,19 @@ class ForeignObjectRel(FieldCacheMixin):
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
if ordering:
qs = qs.order_by(*ordering)
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
return (blank_choice if include_blank else []) + [
(x.pk, str(x)) for x in qs
]
def is_hidden(self):
"""Should the related object be hidden?"""
return bool(self.related_name) and self.related_name[-1] == "+"
return bool(self.related_name) and self.related_name[-1] == '+'
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
def get_extra_restriction(self, alias, related_alias):
return self.field.get_extra_restriction(related_alias, alias)
def get_extra_restriction(self, where_class, alias, related_alias):
return self.field.get_extra_restriction(where_class, related_alias, alias)
def set_field_name(self):
"""
@@ -200,13 +187,12 @@ class ForeignObjectRel(FieldCacheMixin):
opts = model._meta if model else self.related_model._meta
model = model or self.related_model
if self.multiple:
# If this is a symmetrical m2m relation on self, there is no
# reverse accessor.
# If this is a symmetrical m2m relation on self, there is no reverse accessor.
if self.symmetrical and model == self.model:
return None
if self.related_name:
return self.related_name
return opts.model_name + ("_set" if self.multiple else "")
return opts.model_name + ('_set' if self.multiple else '')
def get_path_info(self, filtered_relation=None):
return self.field.get_reverse_path_info(filtered_relation)
@@ -234,20 +220,10 @@ class ManyToOneRel(ForeignObjectRel):
reverse relations into actual fields.
"""
def __init__(
self,
field,
to,
field_name,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
super().__init__(
field,
to,
field, to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -259,7 +235,7 @@ class ManyToOneRel(ForeignObjectRel):
def __getstate__(self):
state = self.__dict__.copy()
state.pop("related_model", None)
state.pop('related_model', None)
return state
@property
@@ -272,9 +248,7 @@ class ManyToOneRel(ForeignObjectRel):
"""
field = self.model._meta.get_field(self.field_name)
if not field.concrete:
raise exceptions.FieldDoesNotExist(
"No related field named '%s'" % self.field_name
)
raise exceptions.FieldDoesNotExist("No related field named '%s'" % self.field_name)
return field
def set_field_name(self):
@@ -289,21 +263,10 @@ class OneToOneRel(ManyToOneRel):
flags for the reverse relation.
"""
def __init__(
self,
field,
to,
field_name,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
super().__init__(
field,
to,
field_name,
field, to, field_name,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -322,21 +285,11 @@ class ManyToManyRel(ForeignObjectRel):
flags for the reverse relation.
"""
def __init__(
self,
field,
to,
related_name=None,
related_query_name=None,
limit_choices_to=None,
symmetrical=True,
through=None,
through_fields=None,
db_constraint=True,
):
def __init__(self, field, to, related_name=None, related_query_name=None,
limit_choices_to=None, symmetrical=True, through=None,
through_fields=None, db_constraint=True):
super().__init__(
field,
to,
field, to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -357,7 +310,7 @@ class ManyToManyRel(ForeignObjectRel):
def identity(self):
return super().identity + (
self.through,
make_hashable(self.through_fields),
self.through_fields,
self.db_constraint,
)
@@ -371,7 +324,7 @@ class ManyToManyRel(ForeignObjectRel):
field = opts.get_field(self.through_fields[0])
else:
for field in opts.fields:
rel = getattr(field, "remote_field", None)
rel = getattr(field, 'remote_field', None)
if rel and rel.model == self.model:
break
return field.foreign_related_fields[0]
@@ -1,190 +1,46 @@
from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf
from .comparison import (
Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf,
)
from .datetime import (
Extract,
ExtractDay,
ExtractHour,
ExtractIsoWeekDay,
ExtractIsoYear,
ExtractMinute,
ExtractMonth,
ExtractQuarter,
ExtractSecond,
ExtractWeek,
ExtractWeekDay,
ExtractYear,
Now,
Trunc,
TruncDate,
TruncDay,
TruncHour,
TruncMinute,
TruncMonth,
TruncQuarter,
TruncSecond,
TruncTime,
TruncWeek,
Extract, ExtractDay, ExtractHour, ExtractIsoWeekDay, ExtractIsoYear,
ExtractMinute, ExtractMonth, ExtractQuarter, ExtractSecond, ExtractWeek,
ExtractWeekDay, ExtractYear, Now, Trunc, TruncDate, TruncDay, TruncHour,
TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncWeek,
TruncYear,
)
from .math import (
Abs,
ACos,
ASin,
ATan,
ATan2,
Ceil,
Cos,
Cot,
Degrees,
Exp,
Floor,
Ln,
Log,
Mod,
Pi,
Power,
Radians,
Random,
Round,
Sign,
Sin,
Sqrt,
Tan,
Abs, ACos, ASin, ATan, ATan2, Ceil, Cos, Cot, Degrees, Exp, Floor, Ln, Log,
Mod, Pi, Power, Radians, Random, Round, Sign, Sin, Sqrt, Tan,
)
from .text import (
MD5,
SHA1,
SHA224,
SHA256,
SHA384,
SHA512,
Chr,
Concat,
ConcatPair,
Left,
Length,
Lower,
LPad,
LTrim,
Ord,
Repeat,
Replace,
Reverse,
Right,
RPad,
RTrim,
StrIndex,
Substr,
Trim,
Upper,
MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left,
Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad,
RTrim, StrIndex, Substr, Trim, Upper,
)
from .window import (
CumeDist,
DenseRank,
FirstValue,
Lag,
LastValue,
Lead,
NthValue,
Ntile,
PercentRank,
Rank,
RowNumber,
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
PercentRank, Rank, RowNumber,
)
__all__ = [
# comparison and conversion
"Cast",
"Coalesce",
"Collate",
"Greatest",
"JSONObject",
"Least",
"NullIf",
'Cast', 'Coalesce', 'Collate', 'Greatest', 'JSONObject', 'Least', 'NullIf',
# datetime
"Extract",
"ExtractDay",
"ExtractHour",
"ExtractMinute",
"ExtractMonth",
"ExtractQuarter",
"ExtractSecond",
"ExtractWeek",
"ExtractIsoWeekDay",
"ExtractWeekDay",
"ExtractIsoYear",
"ExtractYear",
"Now",
"Trunc",
"TruncDate",
"TruncDay",
"TruncHour",
"TruncMinute",
"TruncMonth",
"TruncQuarter",
"TruncSecond",
"TruncTime",
"TruncWeek",
"TruncYear",
'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractIsoWeekDay',
'ExtractWeekDay', 'ExtractIsoYear', 'ExtractYear', 'Now', 'Trunc',
'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncWeek', 'TruncYear',
# math
"Abs",
"ACos",
"ASin",
"ATan",
"ATan2",
"Ceil",
"Cos",
"Cot",
"Degrees",
"Exp",
"Floor",
"Ln",
"Log",
"Mod",
"Pi",
"Power",
"Radians",
"Random",
"Round",
"Sign",
"Sin",
"Sqrt",
"Tan",
'Abs', 'ACos', 'ASin', 'ATan', 'ATan2', 'Ceil', 'Cos', 'Cot', 'Degrees',
'Exp', 'Floor', 'Ln', 'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random',
'Round', 'Sign', 'Sin', 'Sqrt', 'Tan',
# text
"MD5",
"SHA1",
"SHA224",
"SHA256",
"SHA384",
"SHA512",
"Chr",
"Concat",
"ConcatPair",
"Left",
"Length",
"Lower",
"LPad",
"LTrim",
"Ord",
"Repeat",
"Replace",
"Reverse",
"Right",
"RPad",
"RTrim",
"StrIndex",
"Substr",
"Trim",
"Upper",
'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat',
'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat',
'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
'Trim', 'Upper',
# window
"CumeDist",
"DenseRank",
"FirstValue",
"Lag",
"LastValue",
"Lead",
"NthValue",
"Ntile",
"PercentRank",
"Rank",
"RowNumber",
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]
@@ -7,115 +7,84 @@ from django.utils.regex_helper import _lazy_re_compile
class Cast(Func):
"""Coerce an expression to a new field type."""
function = "CAST"
template = "%(function)s(%(expressions)s AS %(db_type)s)"
function = 'CAST'
template = '%(function)s(%(expressions)s AS %(db_type)s)'
def __init__(self, expression, output_field):
super().__init__(expression, output_field=output_field)
def as_sql(self, compiler, connection, **extra_context):
extra_context["db_type"] = self.output_field.cast_db_type(connection)
extra_context['db_type'] = self.output_field.cast_db_type(connection)
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
db_type = self.output_field.db_type(connection)
if db_type in {"datetime", "time"}:
if db_type in {'datetime', 'time'}:
# Use strftime as datetime/time don't keep fractional seconds.
template = "strftime(%%s, %(expressions)s)"
sql, params = super().as_sql(
compiler, connection, template=template, **extra_context
)
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
template = 'strftime(%%s, %(expressions)s)'
sql, params = super().as_sql(compiler, connection, template=template, **extra_context)
format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f'
params.insert(0, format_string)
return sql, params
elif db_type == "date":
template = "date(%(expressions)s)"
return super().as_sql(
compiler, connection, template=template, **extra_context
)
elif db_type == 'date':
template = 'date(%(expressions)s)'
return super().as_sql(compiler, connection, template=template, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
def as_mysql(self, compiler, connection, **extra_context):
template = None
output_type = self.output_field.get_internal_type()
# MySQL doesn't support explicit cast to float.
if output_type == "FloatField":
template = "(%(expressions)s + 0.0)"
if output_type == 'FloatField':
template = '(%(expressions)s + 0.0)'
# MariaDB doesn't support explicit cast to JSON.
elif output_type == "JSONField" and connection.mysql_is_mariadb:
elif output_type == 'JSONField' and connection.mysql_is_mariadb:
template = "JSON_EXTRACT(%(expressions)s, '$')"
return self.as_sql(compiler, connection, template=template, **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
# CAST would be valid too, but the :: shortcut syntax is more readable.
# 'expressions' is wrapped in parentheses in case it's a complex
# expression.
return self.as_sql(
compiler,
connection,
template="(%(expressions)s)::%(db_type)s",
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == "JSONField":
if self.output_field.get_internal_type() == 'JSONField':
# Oracle doesn't support explicit cast to JSON.
template = "JSON_QUERY(%(expressions)s, '$')"
return super().as_sql(
compiler, connection, template=template, **extra_context
)
return super().as_sql(compiler, connection, template=template, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Coalesce(Func):
"""Return, from left to right, the first non-null expression."""
function = "COALESCE"
function = 'COALESCE'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Coalesce must take at least two expressions")
raise ValueError('Coalesce must take at least two expressions')
super().__init__(*expressions, **extra)
@property
def empty_result_set_value(self):
for expression in self.get_source_expressions():
result = expression.empty_result_set_value
if result is NotImplemented or result is not None:
return result
return None
def as_oracle(self, compiler, connection, **extra_context):
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
# so convert all fields to NCLOB when that type is expected.
if self.output_field.get_internal_type() == "TextField":
if self.output_field.get_internal_type() == 'TextField':
clone = self.copy()
clone.set_source_expressions(
[
Func(expression, function="TO_NCLOB")
for expression in self.get_source_expressions()
]
)
clone.set_source_expressions([
Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions()
])
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
return self.as_sql(compiler, connection, **extra_context)
class Collate(Func):
function = "COLLATE"
template = "%(expressions)s %(function)s %(collation)s"
# Inspired from
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
collation_re = _lazy_re_compile(r"^[\w\-]+$")
function = 'COLLATE'
template = '%(expressions)s %(function)s %(collation)s'
# Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
collation_re = _lazy_re_compile(r'^[\w\-]+$')
def __init__(self, expression, collation):
if not (collation and self.collation_re.match(collation)):
raise ValueError("Invalid collation name: %r." % collation)
raise ValueError('Invalid collation name: %r.' % collation)
self.collation = collation
super().__init__(expression)
def as_sql(self, compiler, connection, **extra_context):
extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
extra_context.setdefault('collation', connection.ops.quote_name(self.collation))
return super().as_sql(compiler, connection, **extra_context)
@@ -127,21 +96,20 @@ class Greatest(Func):
On PostgreSQL, the maximum not-null expression is returned.
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
"""
function = "GREATEST"
function = 'GREATEST'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Greatest must take at least two expressions")
raise ValueError('Greatest must take at least two expressions')
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MAX function on SQLite."""
return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
class JSONObject(Func):
function = "JSON_OBJECT"
function = 'JSON_OBJECT'
output_field = JSONField()
def __init__(self, **fields):
@@ -153,7 +121,7 @@ class JSONObject(Func):
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.has_json_object_function:
raise NotSupportedError(
"JSONObject() is not supported on this database backend."
'JSONObject() is not supported on this database backend.'
)
return super().as_sql(compiler, connection, **extra_context)
@@ -161,21 +129,21 @@ class JSONObject(Func):
return self.as_sql(
compiler,
connection,
function="JSONB_BUILD_OBJECT",
function='JSONB_BUILD_OBJECT',
**extra_context,
)
def as_oracle(self, compiler, connection, **extra_context):
class ArgJoiner:
def join(self, args):
args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])]
return ", ".join(args)
args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])]
return ', '.join(args)
return self.as_sql(
compiler,
connection,
arg_joiner=ArgJoiner(),
template="%(function)s(%(expressions)s RETURNING CLOB)",
template='%(function)s(%(expressions)s RETURNING CLOB)',
**extra_context,
)
@@ -188,25 +156,24 @@ class Least(Func):
On PostgreSQL, return the minimum not-null expression.
On MySQL, Oracle, and SQLite, if any expression is null, return null.
"""
function = "LEAST"
function = 'LEAST'
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Least must take at least two expressions")
raise ValueError('Least must take at least two expressions')
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
"""Use the MIN function on SQLite."""
return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
class NullIf(Func):
function = "NULLIF"
function = 'NULLIF'
arity = 2
def as_oracle(self, compiler, connection, **extra_context):
expression1 = self.get_source_expressions()[0]
if isinstance(expression1, Value) and expression1.value is None:
raise ValueError("Oracle does not allow Value(None) for expression1.")
raise ValueError('Oracle does not allow Value(None) for expression1.')
return super().as_sql(compiler, connection, **extra_context)
@@ -3,20 +3,10 @@ from datetime import datetime
from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
DateField,
DateTimeField,
DurationField,
Field,
IntegerField,
TimeField,
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
)
from django.db.models.lookups import (
Transform,
YearExact,
YearGt,
YearGte,
YearLt,
YearLte,
Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
)
from django.utils import timezone
@@ -46,7 +36,7 @@ class Extract(TimezoneMixin, Transform):
if self.lookup_name is None:
self.lookup_name = lookup_name
if self.lookup_name is None:
raise ValueError("lookup_name must be provided")
raise ValueError('lookup_name must be provided')
self.tzinfo = tzinfo
super().__init__(expression, **extra)
@@ -57,16 +47,14 @@ class Extract(TimezoneMixin, Transform):
tzname = self.get_tzname()
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
raise ValueError('tzinfo can only be used with DateTimeField.')
elif isinstance(lhs_output_field, DateField):
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, TimeField):
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
elif isinstance(lhs_output_field, DurationField):
if not connection.features.has_native_duration_field:
raise ValueError(
"Extract requires native DurationField database support."
)
raise ValueError('Extract requires native DurationField database support.')
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
else:
# resolve_expression has already validated the output_field so this
@@ -74,38 +62,22 @@ class Extract(TimezoneMixin, Transform):
assert False, "Tried to Extract from an invalid type."
return sql, params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
copy = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
field = getattr(copy.lhs, "output_field", None)
if field is None:
return copy
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = copy.lhs.output_field
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
raise ValueError(
"Extract input expression must be DateField, DateTimeField, "
"TimeField, or DurationField."
'Extract input expression must be DateField, DateTimeField, '
'TimeField, or DurationField.'
)
# Passing dates to functions expecting datetimes is most likely a mistake.
if type(field) == DateField and copy.lookup_name in (
"hour",
"minute",
"second",
):
if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
raise ValueError(
"Cannot extract time component '%s' from DateField '%s'."
% (copy.lookup_name, field.name)
"Cannot extract time component '%s' from DateField '%s'. " % (copy.lookup_name, field.name)
)
if isinstance(field, DurationField) and copy.lookup_name in (
"year",
"iso_year",
"month",
"week",
"week_day",
"iso_week_day",
"quarter",
if (
isinstance(field, DurationField) and
copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter')
):
raise ValueError(
"Cannot extract component '%s' from DurationField '%s'."
@@ -115,21 +87,20 @@ class Extract(TimezoneMixin, Transform):
class ExtractYear(Extract):
lookup_name = "year"
lookup_name = 'year'
class ExtractIsoYear(Extract):
"""Return the ISO-8601 week-numbering year."""
lookup_name = "iso_year"
lookup_name = 'iso_year'
class ExtractMonth(Extract):
lookup_name = "month"
lookup_name = 'month'
class ExtractDay(Extract):
lookup_name = "day"
lookup_name = 'day'
class ExtractWeek(Extract):
@@ -137,8 +108,7 @@ class ExtractWeek(Extract):
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
week.
"""
lookup_name = "week"
lookup_name = 'week'
class ExtractWeekDay(Extract):
@@ -147,30 +117,28 @@ class ExtractWeekDay(Extract):
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
"""
lookup_name = "week_day"
lookup_name = 'week_day'
class ExtractIsoWeekDay(Extract):
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
lookup_name = "iso_week_day"
lookup_name = 'iso_week_day'
class ExtractQuarter(Extract):
lookup_name = "quarter"
lookup_name = 'quarter'
class ExtractHour(Extract):
lookup_name = "hour"
lookup_name = 'hour'
class ExtractMinute(Extract):
lookup_name = "minute"
lookup_name = 'minute'
class ExtractSecond(Extract):
lookup_name = "second"
lookup_name = 'second'
DateField.register_lookup(ExtractYear)
@@ -204,32 +172,21 @@ ExtractIsoYear.register_lookup(YearLte)
class Now(Func):
template = "CURRENT_TIMESTAMP"
template = 'CURRENT_TIMESTAMP'
output_field = DateTimeField()
def as_postgresql(self, compiler, connection, **extra_context):
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
# other databases.
return self.as_sql(
compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
)
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
class TruncBase(TimezoneMixin, Transform):
kind = None
tzinfo = None
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(
self,
expression,
output_field=None,
tzinfo=None,
is_dst=timezone.NOT_PASSED,
**extra,
):
def __init__(self, expression, output_field=None, tzinfo=None, is_dst=None, **extra):
self.tzinfo = tzinfo
self.is_dst = is_dst
super().__init__(expression, output_field=output_field, **extra)
@@ -240,7 +197,7 @@ class TruncBase(TimezoneMixin, Transform):
if isinstance(self.lhs.output_field, DateTimeField):
tzname = self.get_tzname()
elif self.tzinfo is not None:
raise ValueError("tzinfo can only be used with DateTimeField.")
raise ValueError('tzinfo can only be used with DateTimeField.')
if isinstance(self.output_field, DateTimeField):
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
elif isinstance(self.output_field, DateField):
@@ -248,66 +205,36 @@ class TruncBase(TimezoneMixin, Transform):
elif isinstance(self.output_field, TimeField):
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
else:
raise ValueError(
"Trunc only valid on DateField, TimeField, or DateTimeField."
)
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
return sql, inner_params
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
copy = super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
field = copy.lhs.output_field
# DateTimeField is a subclass of DateField so this works for both.
if not isinstance(field, (DateField, TimeField)):
raise TypeError(
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
)
assert isinstance(field, (DateField, TimeField)), (
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
)
# If self.output_field was None, then accessing the field will trigger
# the resolver to assign it to self.lhs.output_field.
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
raise ValueError(
"output_field must be either DateField, TimeField, or DateTimeField"
)
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
# Passing dates or times to functions expecting datetimes is most
# likely a mistake.
class_output_field = (
self.__class__.output_field
if isinstance(self.__class__.output_field, Field)
else None
)
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
output_field = class_output_field or copy.output_field
has_explicit_output_field = (
class_output_field or field.__class__ is not copy.output_field.__class__
)
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
if type(field) == DateField and (
isinstance(output_field, DateTimeField)
or copy.kind in ("hour", "minute", "second", "time")
):
raise ValueError(
"Cannot truncate DateField '%s' to %s."
% (
field.name,
output_field.__class__.__name__
if has_explicit_output_field
else "DateTimeField",
)
)
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
raise ValueError("Cannot truncate DateField '%s' to %s. " % (
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
elif isinstance(field, TimeField) and (
isinstance(output_field, DateTimeField)
or copy.kind in ("year", "quarter", "month", "week", "day", "date")
):
raise ValueError(
"Cannot truncate TimeField '%s' to %s."
% (
field.name,
output_field.__class__.__name__
if has_explicit_output_field
else "DateTimeField",
)
)
isinstance(output_field, DateTimeField) or
copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
))
return copy
def convert_value(self, value, expression, connection):
@@ -319,8 +246,8 @@ class TruncBase(TimezoneMixin, Transform):
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
elif not connection.features.has_zoneinfo_database:
raise ValueError(
"Database returned an invalid datetime value. Are time "
"zone definitions for your database installed?"
'Database returned an invalid datetime value. Are time '
'zone definitions for your database installed?'
)
elif isinstance(value, datetime):
if value is None:
@@ -334,48 +261,38 @@ class TruncBase(TimezoneMixin, Transform):
class Trunc(TruncBase):
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
# argument.
def __init__(
self,
expression,
kind,
output_field=None,
tzinfo=None,
is_dst=timezone.NOT_PASSED,
**extra,
):
def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=None, **extra):
self.kind = kind
super().__init__(
expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
expression, output_field=output_field, tzinfo=tzinfo,
is_dst=is_dst, **extra
)
class TruncYear(TruncBase):
kind = "year"
kind = 'year'
class TruncQuarter(TruncBase):
kind = "quarter"
kind = 'quarter'
class TruncMonth(TruncBase):
kind = "month"
kind = 'month'
class TruncWeek(TruncBase):
"""Truncate to midnight on the Monday of the week."""
kind = "week"
kind = 'week'
class TruncDay(TruncBase):
kind = "day"
kind = 'day'
class TruncDate(TruncBase):
kind = "date"
lookup_name = "date"
kind = 'date'
lookup_name = 'date'
output_field = DateField()
def as_sql(self, compiler, connection):
@@ -387,8 +304,8 @@ class TruncDate(TruncBase):
class TruncTime(TruncBase):
kind = "time"
lookup_name = "time"
kind = 'time'
lookup_name = 'time'
output_field = TimeField()
def as_sql(self, compiler, connection):
@@ -400,15 +317,15 @@ class TruncTime(TruncBase):
class TruncHour(TruncBase):
kind = "hour"
kind = 'hour'
class TruncMinute(TruncBase):
kind = "minute"
kind = 'minute'
class TruncSecond(TruncBase):
kind = "second"
kind = 'second'
DateTimeField.register_lookup(TruncDate)
@@ -1,43 +1,40 @@
import math
from django.db.models.expressions import Func, Value
from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
FixDecimalInputMixin,
NumericOutputFieldMixin,
FixDecimalInputMixin, NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform
class Abs(Transform):
function = "ABS"
lookup_name = "abs"
function = 'ABS'
lookup_name = 'abs'
class ACos(NumericOutputFieldMixin, Transform):
function = "ACOS"
lookup_name = "acos"
function = 'ACOS'
lookup_name = 'acos'
class ASin(NumericOutputFieldMixin, Transform):
function = "ASIN"
lookup_name = "asin"
function = 'ASIN'
lookup_name = 'asin'
class ATan(NumericOutputFieldMixin, Transform):
function = "ATAN"
lookup_name = "atan"
function = 'ATAN'
lookup_name = 'atan'
class ATan2(NumericOutputFieldMixin, Func):
function = "ATAN2"
function = 'ATAN2'
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(
connection.ops, "spatialite", False
) or connection.ops.spatial_version >= (5, 0, 0):
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version >= (5, 0, 0):
return self.as_sql(compiler, connection)
# This function is usually ATan2(y, x), returning the inverse tangent
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
@@ -45,74 +42,67 @@ class ATan2(NumericOutputFieldMixin, Func):
# arguments are mixed between integer and float or decimal.
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
clone = self.copy()
clone.set_source_expressions(
[
Cast(expression, FloatField())
if isinstance(expression.output_field, IntegerField)
else expression
for expression in self.get_source_expressions()[::-1]
]
)
clone.set_source_expressions([
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
else expression for expression in self.get_source_expressions()[::-1]
])
return clone.as_sql(compiler, connection, **extra_context)
class Ceil(Transform):
function = "CEILING"
lookup_name = "ceil"
function = 'CEILING'
lookup_name = 'ceil'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="CEIL", **extra_context)
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
class Cos(NumericOutputFieldMixin, Transform):
function = "COS"
lookup_name = "cos"
function = 'COS'
lookup_name = 'cos'
class Cot(NumericOutputFieldMixin, Transform):
function = "COT"
lookup_name = "cot"
function = 'COT'
lookup_name = 'cot'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
)
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
class Degrees(NumericOutputFieldMixin, Transform):
function = "DEGREES"
lookup_name = "degrees"
function = 'DEGREES'
lookup_name = 'degrees'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="((%%(expressions)s) * 180 / %s)" % math.pi,
**extra_context,
compiler, connection,
template='((%%(expressions)s) * 180 / %s)' % math.pi,
**extra_context
)
class Exp(NumericOutputFieldMixin, Transform):
function = "EXP"
lookup_name = "exp"
function = 'EXP'
lookup_name = 'exp'
class Floor(Transform):
function = "FLOOR"
lookup_name = "floor"
function = 'FLOOR'
lookup_name = 'floor'
class Ln(NumericOutputFieldMixin, Transform):
function = "LN"
lookup_name = "ln"
function = 'LN'
lookup_name = 'ln'
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = "LOG"
function = 'LOG'
arity = 2
def as_sqlite(self, compiler, connection, **extra_context):
if not getattr(connection.ops, "spatialite", False):
if not getattr(connection.ops, 'spatialite', False):
return self.as_sql(compiler, connection)
# This function is usually Log(b, x) returning the logarithm of x to
# the base b, but on SpatiaLite it's Log(x, b).
@@ -122,91 +112,72 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = "MOD"
function = 'MOD'
arity = 2
class Pi(NumericOutputFieldMixin, Func):
function = "PI"
function = 'PI'
arity = 0
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, template=str(math.pi), **extra_context
)
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
class Power(NumericOutputFieldMixin, Func):
function = "POWER"
function = 'POWER'
arity = 2
class Radians(NumericOutputFieldMixin, Transform):
function = "RADIANS"
lookup_name = "radians"
function = 'RADIANS'
lookup_name = 'radians'
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="((%%(expressions)s) * %s / 180)" % math.pi,
**extra_context,
compiler, connection,
template='((%%(expressions)s) * %s / 180)' % math.pi,
**extra_context
)
class Random(NumericOutputFieldMixin, Func):
function = "RANDOM"
function = 'RANDOM'
arity = 0
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="RAND", **extra_context)
return super().as_sql(compiler, connection, function='RAND', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
)
return super().as_sql(compiler, connection, function='DBMS_RANDOM.VALUE', **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="RAND", **extra_context)
return super().as_sql(compiler, connection, function='RAND', **extra_context)
def get_group_by_cols(self, alias=None):
return []
class Round(FixDecimalInputMixin, Transform):
function = "ROUND"
lookup_name = "round"
arity = None # Override Transform's arity=1 to enable passing precision.
def __init__(self, expression, precision=0, **extra):
super().__init__(expression, precision, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
precision = self.get_source_expressions()[1]
if isinstance(precision, Value) and precision.value < 0:
raise ValueError("SQLite does not support negative precision.")
return super().as_sqlite(compiler, connection, **extra_context)
def _resolve_output_field(self):
source = self.get_source_expressions()[0]
return source.output_field
class Round(Transform):
function = 'ROUND'
lookup_name = 'round'
class Sign(Transform):
function = "SIGN"
lookup_name = "sign"
function = 'SIGN'
lookup_name = 'sign'
class Sin(NumericOutputFieldMixin, Transform):
function = "SIN"
lookup_name = "sin"
function = 'SIN'
lookup_name = 'sin'
class Sqrt(NumericOutputFieldMixin, Transform):
function = "SQRT"
lookup_name = "sqrt"
function = 'SQRT'
lookup_name = 'sqrt'
class Tan(NumericOutputFieldMixin, Transform):
function = "TAN"
lookup_name = "tan"
function = 'TAN'
lookup_name = 'tan'
@@ -5,6 +5,7 @@ from django.db.models.functions import Cast
class FixDecimalInputMixin:
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
@@ -12,42 +13,36 @@ class FixDecimalInputMixin:
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
clone.set_source_expressions(
[
Cast(expression, output_field)
if isinstance(expression.output_field, FloatField)
else expression
for expression in self.get_source_expressions()
]
)
clone.set_source_expressions([
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
else expression for expression in self.get_source_expressions()
])
return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == "DurationField":
sql = "CAST(%s AS SIGNED)" % sql
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s AS SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == "DurationField":
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import (
IntervalToSeconds,
SecondsToInterval,
IntervalToSeconds, SecondsToInterval,
)
return compiler.compile(
SecondsToInterval(
self.__class__(IntervalToSeconds(expression), **options)
)
SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
def _resolve_output_field(self):
source_fields = self.get_source_fields()
if any(isinstance(s, DecimalField) for s in source_fields):
@@ -10,7 +10,7 @@ class MySQLSHA2Mixin:
return super().as_sql(
compiler,
connection,
template="SHA2(%%(expressions)s, %s)" % self.function[3:],
template='SHA2(%%(expressions)s, %s)' % self.function[3:],
**extra_content,
)
@@ -40,28 +40,25 @@ class PostgreSQLSHAMixin:
class Chr(Transform):
function = "CHR"
lookup_name = "chr"
function = 'CHR'
lookup_name = 'chr'
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
function="CHAR",
template="%(function)s(%(expressions)s USING utf16)",
**extra_context,
compiler, connection, function='CHAR',
template='%(function)s(%(expressions)s USING utf16)',
**extra_context
)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template="%(function)s(%(expressions)s USING NCHAR_CS)",
**extra_context,
compiler, connection,
template='%(function)s(%(expressions)s USING NCHAR_CS)',
**extra_context
)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="CHAR", **extra_context)
return super().as_sql(compiler, connection, function='CHAR', **extra_context)
class ConcatPair(Func):
@@ -69,38 +66,29 @@ class ConcatPair(Func):
Concatenate two arguments together. This is used by `Concat` because not
all backend databases support more than two arguments.
"""
function = "CONCAT"
function = 'CONCAT'
def as_sqlite(self, compiler, connection, **extra_context):
coalesced = self.coalesce()
return super(ConcatPair, coalesced).as_sql(
compiler,
connection,
template="%(expressions)s",
arg_joiner=" || ",
**extra_context,
compiler, connection, template='%(expressions)s', arg_joiner=' || ',
**extra_context
)
def as_mysql(self, compiler, connection, **extra_context):
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
return super().as_sql(
compiler,
connection,
function="CONCAT_WS",
compiler, connection, function='CONCAT_WS',
template="%(function)s('', %(expressions)s)",
**extra_context,
**extra_context
)
def coalesce(self):
# null on either side results in null for expression, wrap with coalesce
c = self.copy()
c.set_source_expressions(
[
Coalesce(expression, Value(""))
for expression in c.get_source_expressions()
]
)
c.set_source_expressions([
Coalesce(expression, Value('')) for expression in c.get_source_expressions()
])
return c
@@ -110,13 +98,12 @@ class Concat(Func):
null expression when any arguments are null will wrap each argument in
coalesce functions to ensure a non-null result.
"""
function = None
template = "%(expressions)s"
def __init__(self, *expressions, **extra):
if len(expressions) < 2:
raise ValueError("Concat must take at least two expressions")
raise ValueError('Concat must take at least two expressions')
paired = self._paired(expressions)
super().__init__(paired, **extra)
@@ -130,7 +117,7 @@ class Concat(Func):
class Left(Func):
function = "LEFT"
function = 'LEFT'
arity = 2
output_field = CharField()
@@ -139,7 +126,7 @@ class Left(Func):
expression: the name of a field, or an expression returning a string
length: the number of characters to return from the start of the string
"""
if not hasattr(length, "resolve_expression"):
if not hasattr(length, 'resolve_expression'):
if length < 1:
raise ValueError("'length' must be greater than 0.")
super().__init__(expression, length, **extra)
@@ -156,68 +143,57 @@ class Left(Func):
class Length(Transform):
"""Return the number of characters in the expression."""
function = "LENGTH"
lookup_name = "length"
function = 'LENGTH'
lookup_name = 'length'
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler, connection, function="CHAR_LENGTH", **extra_context
)
return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
class Lower(Transform):
function = "LOWER"
lookup_name = "lower"
function = 'LOWER'
lookup_name = 'lower'
class LPad(Func):
function = "LPAD"
function = 'LPAD'
output_field = CharField()
def __init__(self, expression, length, fill_text=Value(" "), **extra):
if (
not hasattr(length, "resolve_expression")
and length is not None
and length < 0
):
def __init__(self, expression, length, fill_text=Value(' '), **extra):
if not hasattr(length, 'resolve_expression') and length is not None and length < 0:
raise ValueError("'length' must be greater or equal to 0.")
super().__init__(expression, length, fill_text, **extra)
class LTrim(Transform):
function = "LTRIM"
lookup_name = "ltrim"
function = 'LTRIM'
lookup_name = 'ltrim'
class MD5(OracleHashMixin, Transform):
function = "MD5"
lookup_name = "md5"
function = 'MD5'
lookup_name = 'md5'
class Ord(Transform):
function = "ASCII"
lookup_name = "ord"
function = 'ASCII'
lookup_name = 'ord'
output_field = IntegerField()
def as_mysql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="ORD", **extra_context)
return super().as_sql(compiler, connection, function='ORD', **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
class Repeat(Func):
function = "REPEAT"
function = 'REPEAT'
output_field = CharField()
def __init__(self, expression, number, **extra):
if (
not hasattr(number, "resolve_expression")
and number is not None
and number < 0
):
if not hasattr(number, 'resolve_expression') and number is not None and number < 0:
raise ValueError("'number' must be greater or equal to 0.")
super().__init__(expression, number, **extra)
@@ -229,76 +205,73 @@ class Repeat(Func):
class Replace(Func):
function = "REPLACE"
function = 'REPLACE'
def __init__(self, expression, text, replacement=Value(""), **extra):
def __init__(self, expression, text, replacement=Value(''), **extra):
super().__init__(expression, text, replacement, **extra)
class Reverse(Transform):
function = "REVERSE"
lookup_name = "reverse"
function = 'REVERSE'
lookup_name = 'reverse'
def as_oracle(self, compiler, connection, **extra_context):
# REVERSE in Oracle is undocumented and doesn't support multi-byte
# strings. Use a special subquery instead.
return super().as_sql(
compiler,
connection,
compiler, connection,
template=(
"(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
"FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
"GROUP BY %(expressions)s)"
'(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM '
'(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s '
'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) '
'GROUP BY %(expressions)s)'
),
**extra_context,
**extra_context
)
class Right(Left):
function = "RIGHT"
function = 'RIGHT'
def get_substr(self):
return Substr(
self.source_expressions[0], self.source_expressions[1] * Value(-1)
)
return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1))
class RPad(LPad):
function = "RPAD"
function = 'RPAD'
class RTrim(Transform):
function = "RTRIM"
lookup_name = "rtrim"
function = 'RTRIM'
lookup_name = 'rtrim'
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA1"
lookup_name = "sha1"
function = 'SHA1'
lookup_name = 'sha1'
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
function = "SHA224"
lookup_name = "sha224"
function = 'SHA224'
lookup_name = 'sha224'
def as_oracle(self, compiler, connection, **extra_context):
raise NotSupportedError("SHA224 is not supported on Oracle.")
raise NotSupportedError('SHA224 is not supported on Oracle.')
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA256"
lookup_name = "sha256"
function = 'SHA256'
lookup_name = 'sha256'
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA384"
lookup_name = "sha384"
function = 'SHA384'
lookup_name = 'sha384'
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
function = "SHA512"
lookup_name = "sha512"
function = 'SHA512'
lookup_name = 'sha512'
class StrIndex(Func):
@@ -307,17 +280,16 @@ class StrIndex(Func):
first occurrence of a substring inside another string, or 0 if the
substring is not found.
"""
function = "INSTR"
function = 'INSTR'
arity = 2
output_field = IntegerField()
def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
class Substr(Func):
function = "SUBSTRING"
function = 'SUBSTRING'
output_field = CharField()
def __init__(self, expression, pos, length=None, **extra):
@@ -326,7 +298,7 @@ class Substr(Func):
pos: an integer > 0, or an expression returning an integer
length: an optional number of characters to return
"""
if not hasattr(pos, "resolve_expression"):
if not hasattr(pos, 'resolve_expression'):
if pos < 1:
raise ValueError("'pos' must be greater than 0")
expressions = [expression, pos]
@@ -335,17 +307,17 @@ class Substr(Func):
super().__init__(*expressions, **extra)
def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
class Trim(Transform):
function = "TRIM"
lookup_name = "trim"
function = 'TRIM'
lookup_name = 'trim'
class Upper(Transform):
function = "UPPER"
lookup_name = "upper"
function = 'UPPER'
lookup_name = 'upper'
@@ -2,35 +2,26 @@ from django.db.models.expressions import Func
from django.db.models.fields import FloatField, IntegerField
__all__ = [
"CumeDist",
"DenseRank",
"FirstValue",
"Lag",
"LastValue",
"Lead",
"NthValue",
"Ntile",
"PercentRank",
"Rank",
"RowNumber",
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
]
class CumeDist(Func):
function = "CUME_DIST"
function = 'CUME_DIST'
output_field = FloatField()
window_compatible = True
class DenseRank(Func):
function = "DENSE_RANK"
function = 'DENSE_RANK'
output_field = IntegerField()
window_compatible = True
class FirstValue(Func):
arity = 1
function = "FIRST_VALUE"
function = 'FIRST_VALUE'
window_compatible = True
@@ -40,12 +31,13 @@ class LagLeadFunction(Func):
def __init__(self, expression, offset=1, default=None, **extra):
if expression is None:
raise ValueError(
"%s requires a non-null source expression." % self.__class__.__name__
'%s requires a non-null source expression.' %
self.__class__.__name__
)
if offset is None or offset <= 0:
raise ValueError(
"%s requires a positive integer for the offset."
% self.__class__.__name__
'%s requires a positive integer for the offset.' %
self.__class__.__name__
)
args = (expression, offset)
if default is not None:
@@ -58,32 +50,28 @@ class LagLeadFunction(Func):
class Lag(LagLeadFunction):
function = "LAG"
function = 'LAG'
class LastValue(Func):
arity = 1
function = "LAST_VALUE"
function = 'LAST_VALUE'
window_compatible = True
class Lead(LagLeadFunction):
function = "LEAD"
function = 'LEAD'
class NthValue(Func):
function = "NTH_VALUE"
function = 'NTH_VALUE'
window_compatible = True
def __init__(self, expression, nth=1, **extra):
if expression is None:
raise ValueError(
"%s requires a non-null source expression." % self.__class__.__name__
)
raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
if nth is None or nth <= 0:
raise ValueError(
"%s requires a positive integer as for nth." % self.__class__.__name__
)
raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__)
super().__init__(expression, nth, **extra)
def _resolve_output_field(self):
@@ -92,29 +80,29 @@ class NthValue(Func):
class Ntile(Func):
function = "NTILE"
function = 'NTILE'
output_field = IntegerField()
window_compatible = True
def __init__(self, num_buckets=1, **extra):
if num_buckets <= 0:
raise ValueError("num_buckets must be greater than 0.")
raise ValueError('num_buckets must be greater than 0.')
super().__init__(num_buckets, **extra)
class PercentRank(Func):
function = "PERCENT_RANK"
function = 'PERCENT_RANK'
output_field = FloatField()
window_compatible = True
class Rank(Func):
function = "RANK"
function = 'RANK'
output_field = IntegerField()
window_compatible = True
class RowNumber(Func):
function = "ROW_NUMBER"
function = 'ROW_NUMBER'
output_field = IntegerField()
window_compatible = True
@@ -5,11 +5,11 @@ from django.db.models.query_utils import Q
from django.db.models.sql import Query
from django.utils.functional import partition
__all__ = ["Index"]
__all__ = ['Index']
class Index:
suffix = "idx"
suffix = 'idx'
# The max length of the name of the index (restricted to 30 for
# cross-database compatibility with Oracle)
max_name_length = 30
@@ -25,48 +25,46 @@ class Index:
include=None,
):
if opclasses and not name:
raise ValueError("An index must be named to use opclasses.")
raise ValueError('An index must be named to use opclasses.')
if not isinstance(condition, (type(None), Q)):
raise ValueError("Index.condition must be a Q instance.")
raise ValueError('Index.condition must be a Q instance.')
if condition and not name:
raise ValueError("An index must be named to use condition.")
raise ValueError('An index must be named to use condition.')
if not isinstance(fields, (list, tuple)):
raise ValueError("Index.fields must be a list or tuple.")
raise ValueError('Index.fields must be a list or tuple.')
if not isinstance(opclasses, (list, tuple)):
raise ValueError("Index.opclasses must be a list or tuple.")
raise ValueError('Index.opclasses must be a list or tuple.')
if not expressions and not fields:
raise ValueError(
"At least one field or expression is required to define an index."
'At least one field or expression is required to define an '
'index.'
)
if expressions and fields:
raise ValueError(
"Index.fields and expressions are mutually exclusive.",
'Index.fields and expressions are mutually exclusive.',
)
if expressions and not name:
raise ValueError("An index must be named to use expressions.")
raise ValueError('An index must be named to use expressions.')
if expressions and opclasses:
raise ValueError(
"Index.opclasses cannot be used with expressions. Use "
"django.contrib.postgres.indexes.OpClass() instead."
'Index.opclasses cannot be used with expressions. Use '
'django.contrib.postgres.indexes.OpClass() instead.'
)
if opclasses and len(fields) != len(opclasses):
raise ValueError(
"Index.fields and Index.opclasses must have the same number of "
"elements."
)
raise ValueError('Index.fields and Index.opclasses must have the same number of elements.')
if fields and not all(isinstance(field, str) for field in fields):
raise ValueError("Index.fields must contain only strings with field names.")
raise ValueError('Index.fields must contain only strings with field names.')
if include and not name:
raise ValueError("A covering index must be named.")
raise ValueError('A covering index must be named.')
if not isinstance(include, (type(None), list, tuple)):
raise ValueError("Index.include must be a list or tuple.")
raise ValueError('Index.include must be a list or tuple.')
self.fields = list(fields)
# A list of 2-tuple with the field name and ordering ('' or 'DESC').
self.fields_orders = [
(field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
(field_name[1:], 'DESC') if field_name.startswith('-') else (field_name, '')
for field_name in self.fields
]
self.name = name or ""
self.name = name or ''
self.db_tablespace = db_tablespace
self.opclasses = opclasses
self.condition = condition
@@ -89,10 +87,8 @@ class Index:
sql, params = where.as_sql(compiler, schema_editor.connection)
return sql % tuple(schema_editor.quote_value(p) for p in params)
def create_sql(self, model, schema_editor, using="", **kwargs):
include = [
model._meta.get_field(field_name).column for field_name in self.include
]
def create_sql(self, model, schema_editor, using='', **kwargs):
include = [model._meta.get_field(field_name).column for field_name in self.include]
condition = self._get_condition_sql(model, schema_editor)
if self.expressions:
index_expressions = []
@@ -113,36 +109,29 @@ class Index:
col_suffixes = [order[1] for order in self.fields_orders]
expressions = None
return schema_editor._create_index_sql(
model,
fields=fields,
name=self.name,
using=using,
db_tablespace=self.db_tablespace,
col_suffixes=col_suffixes,
opclasses=self.opclasses,
condition=condition,
include=include,
expressions=expressions,
**kwargs,
model, fields=fields, name=self.name, using=using,
db_tablespace=self.db_tablespace, col_suffixes=col_suffixes,
opclasses=self.opclasses, condition=condition, include=include,
expressions=expressions, **kwargs,
)
def remove_sql(self, model, schema_editor, **kwargs):
return schema_editor._delete_index_sql(model, self.name, **kwargs)
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
path = path.replace("django.db.models.indexes", "django.db.models")
kwargs = {"name": self.name}
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
path = path.replace('django.db.models.indexes', 'django.db.models')
kwargs = {'name': self.name}
if self.fields:
kwargs["fields"] = self.fields
kwargs['fields'] = self.fields
if self.db_tablespace is not None:
kwargs["db_tablespace"] = self.db_tablespace
kwargs['db_tablespace'] = self.db_tablespace
if self.opclasses:
kwargs["opclasses"] = self.opclasses
kwargs['opclasses'] = self.opclasses
if self.condition:
kwargs["condition"] = self.condition
kwargs['condition'] = self.condition
if self.include:
kwargs["include"] = self.include
kwargs['include'] = self.include
return (path, self.expressions, kwargs)
def clone(self):
@@ -159,44 +148,36 @@ class Index:
fit its size by truncating the excess length.
"""
_, table_name = split_identifier(model._meta.db_table)
column_names = [
model._meta.get_field(field_name).column
for field_name, order in self.fields_orders
]
column_names = [model._meta.get_field(field_name).column for field_name, order in self.fields_orders]
column_names_with_order = [
(("-%s" if order else "%s") % column_name)
for column_name, (field_name, order) in zip(
column_names, self.fields_orders
)
(('-%s' if order else '%s') % column_name)
for column_name, (field_name, order) in zip(column_names, self.fields_orders)
]
# The length of the parts of the name is based on the default max
# length of 30 characters.
hash_data = [table_name] + column_names_with_order + [self.suffix]
self.name = "%s_%s_%s" % (
self.name = '%s_%s_%s' % (
table_name[:11],
column_names[0][:7],
"%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
'%s_%s' % (names_digest(*hash_data, length=6), self.suffix),
)
if len(self.name) > self.max_name_length:
raise ValueError(
"Index too long for multiple database support. Is self.suffix "
"longer than 3 characters?"
)
if self.name[0] == "_" or self.name[0].isdigit():
self.name = "D%s" % self.name[1:]
assert len(self.name) <= self.max_name_length, (
'Index too long for multiple database support. Is self.suffix '
'longer than 3 characters?'
)
if self.name[0] == '_' or self.name[0].isdigit():
self.name = 'D%s' % self.name[1:]
def __repr__(self):
return "<%s:%s%s%s%s%s%s%s>" % (
self.__class__.__qualname__,
"" if not self.fields else " fields=%s" % repr(self.fields),
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
"" if not self.name else " name=%s" % repr(self.name),
""
if self.db_tablespace is None
else " db_tablespace=%s" % repr(self.db_tablespace),
"" if self.condition is None else " condition=%s" % self.condition,
"" if not self.include else " include=%s" % repr(self.include),
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
return '<%s:%s%s%s%s%s>' % (
self.__class__.__name__,
'' if not self.fields else " fields='%s'" % ', '.join(self.fields),
'' if not self.expressions else " expressions='%s'" % ', '.join([
str(expression) for expression in self.expressions
]),
'' if self.condition is None else ' condition=%s' % self.condition,
'' if not self.include else " include='%s'" % ', '.join(self.include),
'' if not self.opclasses else " opclasses='%s'" % ', '.join(self.opclasses),
)
def __eq__(self, other):
@@ -207,20 +188,17 @@ class Index:
class IndexExpression(Func):
"""Order and wrap expressions for CREATE INDEX statements."""
template = "%(expressions)s"
template = '%(expressions)s'
wrapper_classes = (OrderBy, Collate)
def set_wrapper_classes(self, connection=None):
# Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
if connection and connection.features.collate_as_index_expression:
self.wrapper_classes = tuple(
[
wrapper_cls
for wrapper_cls in self.wrapper_classes
if wrapper_cls is not Collate
]
)
self.wrapper_classes = tuple([
wrapper_cls
for wrapper_cls in self.wrapper_classes
if wrapper_cls is not Collate
])
@classmethod
def register_wrappers(cls, *wrapper_classes):
@@ -244,17 +222,16 @@ class IndexExpression(Func):
if len(wrapper_types) != len(set(wrapper_types)):
raise ValueError(
"Multiple references to %s can't be used in an indexed "
"expression."
% ", ".join(
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
)
"expression." % ', '.join([
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
])
)
if expressions[1 : len(wrappers) + 1] != wrappers:
if expressions[1:len(wrappers) + 1] != wrappers:
raise ValueError(
"%s must be topmost expressions in an indexed expression."
% ", ".join(
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
)
'%s must be topmost expressions in an indexed expression.'
% ', '.join([
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
])
)
# Wrap expressions in parentheses if they are not column references.
root_expression = index_expressions[1]
@@ -266,7 +243,7 @@ class IndexExpression(Func):
for_save,
)
if not isinstance(resolve_root_expression, Col):
root_expression = Func(root_expression, template="(%(expressions)s)")
root_expression = Func(root_expression, template='(%(expressions)s)')
if wrappers:
# Order wrappers and set their expressions.
@@ -283,9 +260,7 @@ class IndexExpression(Func):
else:
# Use the root expression, if there are no wrappers.
self.set_source_expressions([root_expression])
return super().resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
return super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
def as_sqlite(self, compiler, connection, **extra_context):
# Casting to numeric is unnecessary.
+128 -199
View File
@@ -1,23 +1,21 @@
import itertools
import math
import warnings
from copy import copy
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Expression, Func, Value, When
from django.db.models.expressions import Case, Exists, Func, Value, When
from django.db.models.fields import (
BooleanField,
CharField,
DateTimeField,
Field,
IntegerField,
UUIDField,
CharField, DateTimeField, Field, IntegerField, UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
from django.utils.deprecation import RemovedInDjango40Warning
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable
class Lookup(Expression):
class Lookup:
lookup_name = None
prepare_rhs = True
can_use_none_as_rhs = False
@@ -25,20 +23,18 @@ class Lookup(Expression):
def __init__(self, lhs, rhs):
self.lhs, self.rhs = lhs, rhs
self.rhs = self.get_prep_lookup()
self.lhs = self.get_prep_lhs()
if hasattr(self.lhs, "get_bilateral_transforms"):
if hasattr(self.lhs, 'get_bilateral_transforms'):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
if bilateral_transforms:
# Warn the user as soon as possible if they are trying to apply
# a bilateral transformation on a nested QuerySet: that won't work.
from django.db.models.sql.query import Query # avoid circular import
from django.db.models.sql.query import ( # avoid circular import
Query,
)
if isinstance(rhs, Query):
raise NotImplementedError(
"Bilateral transformations on nested querysets are not implemented."
)
raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
self.bilateral_transforms = bilateral_transforms
def apply_bilateral_transforms(self, value):
@@ -46,9 +42,6 @@ class Lookup(Expression):
value = transform(value)
return value
def __repr__(self):
return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"
def batch_process_rhs(self, compiler, connection, rhs=None):
if rhs is None:
rhs = self.rhs
@@ -63,7 +56,7 @@ class Lookup(Expression):
sqls_params.extend(sql_params)
else:
_, params = self.get_db_prep_lookup(rhs, connection)
sqls, sqls_params = ["%s"] * len(params), params
sqls, sqls_params = ['%s'] * len(params), params
return sqls, sqls_params
def get_source_expressions(self):
@@ -78,32 +71,20 @@ class Lookup(Expression):
self.lhs, self.rhs = new_exprs
def get_prep_lookup(self):
if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
if hasattr(self.rhs, 'resolve_expression'):
return self.rhs
if hasattr(self.lhs, "output_field"):
if hasattr(self.lhs.output_field, "get_prep_value"):
return self.lhs.output_field.get_prep_value(self.rhs)
elif self.rhs_is_direct_value():
return Value(self.rhs)
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
return self.lhs.output_field.get_prep_value(self.rhs)
return self.rhs
def get_prep_lhs(self):
if hasattr(self.lhs, "resolve_expression"):
return self.lhs
return Value(self.lhs)
def get_db_prep_lookup(self, value, connection):
return ("%s", [value])
return ('%s', [value])
def process_lhs(self, compiler, connection, lhs=None):
lhs = lhs or self.lhs
if hasattr(lhs, "resolve_expression"):
if hasattr(lhs, 'resolve_expression'):
lhs = lhs.resolve_expression(compiler.query)
sql, params = compiler.compile(lhs)
if isinstance(lhs, Lookup):
# Wrapped in parentheses to respect operator precedence.
sql = f"({sql})"
return sql, params
return compiler.compile(lhs)
def process_rhs(self, compiler, connection):
value = self.rhs
@@ -114,33 +95,37 @@ class Lookup(Expression):
value = Value(value, output_field=self.lhs.output_field)
value = self.apply_bilateral_transforms(value)
value = value.resolve_expression(compiler.query)
if hasattr(value, "as_sql"):
sql, params = compiler.compile(value)
# Ensure expression is wrapped in parentheses to respect operator
# precedence but avoid double wrapping as it can be misinterpreted
# on some backends (e.g. subqueries on SQLite).
if sql and sql[0] != "(":
sql = "(%s)" % sql
return sql, params
if hasattr(value, 'as_sql'):
return compiler.compile(value)
else:
return self.get_db_prep_lookup(value, connection)
def rhs_is_direct_value(self):
return not hasattr(self.rhs, "as_sql")
return not hasattr(self.rhs, 'as_sql')
def relabeled_clone(self, relabels):
new = copy(self)
new.lhs = new.lhs.relabeled_clone(relabels)
if hasattr(new.rhs, 'relabeled_clone'):
new.rhs = new.rhs.relabeled_clone(relabels)
return new
def get_group_by_cols(self, alias=None):
cols = []
for source in self.get_source_expressions():
cols.extend(source.get_group_by_cols())
cols = self.lhs.get_group_by_cols()
if hasattr(self.rhs, 'get_group_by_cols'):
cols.extend(self.rhs.get_group_by_cols())
return cols
def as_sql(self, compiler, connection):
raise NotImplementedError
def as_oracle(self, compiler, connection):
# Oracle doesn't allow EXISTS() and filters to be compared to another
# expression unless they're wrapped in a CASE WHEN.
# Oracle doesn't allow EXISTS() to be compared to another expression
# unless it's wrapped in a CASE WHEN.
wrapped = False
exprs = []
for expr in (self.lhs, self.rhs):
if connection.ops.conditional_expression_supported_in_where_clause(expr):
if isinstance(expr, Exists):
expr = Case(When(expr, then=True), default=False)
wrapped = True
exprs.append(expr)
@@ -148,8 +133,16 @@ class Lookup(Expression):
return lookup.as_sql(compiler, connection)
@cached_property
def output_field(self):
return BooleanField()
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
@cached_property
def contains_over_clause(self):
return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
@property
def is_summary(self):
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
@property
def identity(self):
@@ -163,34 +156,12 @@ class Lookup(Expression):
def __hash__(self):
return hash(make_hashable(self.identity))
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
c = self.copy()
c.is_summary = summarize
c.lhs = self.lhs.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
c.rhs = self.rhs.resolve_expression(
query, allow_joins, reuse, summarize, for_save
)
return c
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params
class Transform(RegisterLookupMixin, Func):
"""
RegisterLookupMixin() is first so that get_lookup() and get_transform()
first examine self and then check output_field.
"""
bilateral = False
arity = 1
@@ -199,7 +170,7 @@ class Transform(RegisterLookupMixin, Func):
return self.get_source_expressions()[0]
def get_bilateral_transforms(self):
if hasattr(self.lhs, "get_bilateral_transforms"):
if hasattr(self.lhs, 'get_bilateral_transforms'):
bilateral_transforms = self.lhs.get_bilateral_transforms()
else:
bilateral_transforms = []
@@ -213,10 +184,9 @@ class BuiltinLookup(Lookup):
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
field_internal_type = self.lhs.output_field.get_internal_type()
db_type = self.lhs.output_field.db_type(connection=connection)
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
lhs_sql = (
connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
)
lhs_sql = connection.ops.field_cast_sql(
db_type, field_internal_type) % lhs_sql
lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
return lhs_sql, list(params)
def as_sql(self, compiler, connection):
@@ -224,7 +194,7 @@ class BuiltinLookup(Lookup):
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
params.extend(rhs_params)
rhs_sql = self.get_rhs_op(connection, rhs_sql)
return "%s %s" % (lhs_sql, rhs_sql), params
return '%s %s' % (lhs_sql, rhs_sql), params
def get_rhs_op(self, connection, rhs):
return connection.operators[self.lookup_name] % rhs
@@ -235,22 +205,18 @@ class FieldGetDbPrepValueMixin:
Some lookups require Field.get_db_prep_value() to be called on their
inputs.
"""
get_db_prep_lookup_value_is_iterable = False
def get_db_prep_lookup(self, value, connection):
# For relational fields, use the 'target_field' attribute of the
# output_field.
field = getattr(self.lhs.output_field, "target_field", None)
get_db_prep_value = (
getattr(field, "get_db_prep_value", None)
or self.lhs.output_field.get_db_prep_value
)
field = getattr(self.lhs.output_field, 'target_field', None)
get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
return (
"%s",
'%s',
[get_db_prep_value(v, connection, prepared=True) for v in value]
if self.get_db_prep_lookup_value_is_iterable
else [get_db_prep_value(value, connection, prepared=True)],
if self.get_db_prep_lookup_value_is_iterable else
[get_db_prep_value(value, connection, prepared=True)]
)
@@ -259,19 +225,18 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
Some lookups require Field.get_db_prep_value() to be called on each value
in an iterable.
"""
get_db_prep_lookup_value_is_iterable = True
def get_prep_lookup(self):
if hasattr(self.rhs, "resolve_expression"):
if hasattr(self.rhs, 'resolve_expression'):
return self.rhs
prepared_values = []
for rhs_value in self.rhs:
if hasattr(rhs_value, "resolve_expression"):
if hasattr(rhs_value, 'resolve_expression'):
# An expression will be handled by the database but can coexist
# alongside real values.
pass
elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
prepared_values.append(rhs_value)
return prepared_values
@@ -286,9 +251,9 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
def resolve_expression_parameter(self, compiler, connection, sql, param):
params = [param]
if hasattr(param, "resolve_expression"):
if hasattr(param, 'resolve_expression'):
param = param.resolve_expression(compiler.query)
if hasattr(param, "as_sql"):
if hasattr(param, 'as_sql'):
sql, params = compiler.compile(param)
return sql, params
@@ -298,44 +263,40 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
# sql/param pair. Zip them to get sql and param pairs that refer to the
# same argument and attempt to replace them with the result of
# compiling the param step.
sql, params = zip(
*(
self.resolve_expression_parameter(compiler, connection, sql, param)
for sql, param in zip(*pre_processed)
)
)
sql, params = zip(*(
self.resolve_expression_parameter(compiler, connection, sql, param)
for sql, param in zip(*pre_processed)
))
params = itertools.chain.from_iterable(params)
return sql, tuple(params)
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
"""Lookup defined by operators on PostgreSQL."""
postgres_operator = None
def as_postgresql(self, compiler, connection):
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return "%s %s %s" % (lhs, self.postgres_operator, rhs), params
return '%s %s %s' % (lhs, self.postgres_operator, rhs), params
@Field.register_lookup
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "exact"
lookup_name = 'exact'
def process_rhs(self, compiler, connection):
from django.db.models.sql.query import Query
if isinstance(self.rhs, Query):
if self.rhs.has_limit_one():
if not self.rhs.has_select_fields:
self.rhs.clear_select_clause()
self.rhs.add_fields(["pk"])
self.rhs.add_fields(['pk'])
else:
raise ValueError(
"The QuerySet value for an exact lookup must be limited to "
"one result using slicing."
'The QuerySet value for an exact lookup must be limited to '
'one result using slicing.'
)
return super().process_rhs(compiler, connection)
@@ -344,21 +305,19 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
# turns "boolfield__exact=True" into "WHERE boolean_field" instead of
# "WHERE boolean_field = True" when allowed.
if (
isinstance(self.rhs, bool)
and getattr(self.lhs, "conditional", False)
and connection.ops.conditional_expression_supported_in_where_clause(
self.lhs
)
isinstance(self.rhs, bool) and
getattr(self.lhs, 'conditional', False) and
connection.ops.conditional_expression_supported_in_where_clause(self.lhs)
):
lhs_sql, params = self.process_lhs(compiler, connection)
template = "%s" if self.rhs else "NOT %s"
template = '%s' if self.rhs else 'NOT %s'
return template % lhs_sql, params
return super().as_sql(compiler, connection)
@Field.register_lookup
class IExact(BuiltinLookup):
lookup_name = "iexact"
lookup_name = 'iexact'
prepare_rhs = False
def process_rhs(self, qn, connection):
@@ -370,22 +329,22 @@ class IExact(BuiltinLookup):
@Field.register_lookup
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "gt"
lookup_name = 'gt'
@Field.register_lookup
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "gte"
lookup_name = 'gte'
@Field.register_lookup
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "lt"
lookup_name = 'lt'
@Field.register_lookup
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
lookup_name = "lte"
lookup_name = 'lte'
class IntegerFieldFloatRounding:
@@ -393,7 +352,6 @@ class IntegerFieldFloatRounding:
Allow floats to work as query values for IntegerField. Without this, the
decimal portion of the float would always be discarded.
"""
def get_prep_lookup(self):
if isinstance(self.rhs, float):
self.rhs = math.ceil(self.rhs)
@@ -412,10 +370,10 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
@Field.register_lookup
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = "in"
lookup_name = 'in'
def process_rhs(self, compiler, connection):
db_rhs = getattr(self.rhs, "_db", None)
db_rhs = getattr(self.rhs, '_db', None)
if db_rhs is not None and db_rhs != connection.alias:
raise ValueError(
"Subqueries aren't allowed across different databases. Force "
@@ -436,39 +394,20 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
# rhs should be an iterable; use batch_process_rhs() to
# prepare/transform those values.
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
placeholder = "(" + ", ".join(sqls) + ")"
placeholder = '(' + ', '.join(sqls) + ')'
return (placeholder, sqls_params)
else:
from django.db.models.sql.query import Query # avoid circular import
if isinstance(self.rhs, Query):
query = self.rhs
query.clear_ordering(clear_default=True)
if not query.has_select_fields:
query.clear_select_clause()
query.add_fields(["pk"])
if not getattr(self.rhs, 'has_select_fields', True):
self.rhs.clear_select_clause()
self.rhs.add_fields(['pk'])
return super().process_rhs(compiler, connection)
def get_group_by_cols(self, alias=None):
cols = self.lhs.get_group_by_cols()
if hasattr(self.rhs, "get_group_by_cols"):
if not getattr(self.rhs, "has_select_fields", True):
self.rhs.clear_select_clause()
self.rhs.add_fields(["pk"])
cols.extend(self.rhs.get_group_by_cols())
return cols
def get_rhs_op(self, connection, rhs):
return "IN %s" % rhs
return 'IN %s' % rhs
def as_sql(self, compiler, connection):
max_in_list_size = connection.ops.max_in_list_size()
if (
self.rhs_is_direct_value()
and max_in_list_size
and len(self.rhs) > max_in_list_size
):
if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
return self.split_parameter_list_as_sql(compiler, connection)
return super().as_sql(compiler, connection)
@@ -478,25 +417,25 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
max_in_list_size = connection.ops.max_in_list_size()
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
in_clause_elements = ["("]
in_clause_elements = ['(']
params = []
for offset in range(0, len(rhs_params), max_in_list_size):
if offset > 0:
in_clause_elements.append(" OR ")
in_clause_elements.append("%s IN (" % lhs)
in_clause_elements.append(' OR ')
in_clause_elements.append('%s IN (' % lhs)
params.extend(lhs_params)
sqls = rhs[offset : offset + max_in_list_size]
sqls_params = rhs_params[offset : offset + max_in_list_size]
param_group = ", ".join(sqls)
sqls = rhs[offset: offset + max_in_list_size]
sqls_params = rhs_params[offset: offset + max_in_list_size]
param_group = ', '.join(sqls)
in_clause_elements.append(param_group)
in_clause_elements.append(")")
in_clause_elements.append(')')
params.extend(sqls_params)
in_clause_elements.append(")")
return "".join(in_clause_elements), params
in_clause_elements.append(')')
return ''.join(in_clause_elements), params
class PatternLookup(BuiltinLookup):
param_pattern = "%%%s%%"
param_pattern = '%%%s%%'
prepare_rhs = False
def get_rhs_op(self, connection, rhs):
@@ -509,10 +448,8 @@ class PatternLookup(BuiltinLookup):
# So, for Python values we don't need any special pattern, but for
# SQL reference values or SQL transformations we need the correct
# pattern added.
if hasattr(self.rhs, "as_sql") or self.bilateral_transforms:
pattern = connection.pattern_ops[self.lookup_name].format(
connection.pattern_esc
)
if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
return pattern.format(rhs)
else:
return super().get_rhs_op(connection, rhs)
@@ -520,47 +457,45 @@ class PatternLookup(BuiltinLookup):
def process_rhs(self, qn, connection):
rhs, params = super().process_rhs(qn, connection)
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
params[0] = self.param_pattern % connection.ops.prep_for_like_query(
params[0]
)
params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
return rhs, params
@Field.register_lookup
class Contains(PatternLookup):
lookup_name = "contains"
lookup_name = 'contains'
@Field.register_lookup
class IContains(Contains):
lookup_name = "icontains"
lookup_name = 'icontains'
@Field.register_lookup
class StartsWith(PatternLookup):
lookup_name = "startswith"
param_pattern = "%s%%"
lookup_name = 'startswith'
param_pattern = '%s%%'
@Field.register_lookup
class IStartsWith(StartsWith):
lookup_name = "istartswith"
lookup_name = 'istartswith'
@Field.register_lookup
class EndsWith(PatternLookup):
lookup_name = "endswith"
param_pattern = "%%%s"
lookup_name = 'endswith'
param_pattern = '%%%s'
@Field.register_lookup
class IEndsWith(EndsWith):
lookup_name = "iendswith"
lookup_name = 'iendswith'
@Field.register_lookup
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
lookup_name = "range"
lookup_name = 'range'
def get_rhs_op(self, connection, rhs):
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
@@ -568,13 +503,20 @@ class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
@Field.register_lookup
class IsNull(BuiltinLookup):
lookup_name = "isnull"
lookup_name = 'isnull'
prepare_rhs = False
def as_sql(self, compiler, connection):
if not isinstance(self.rhs, bool):
raise ValueError(
"The QuerySet value for an isnull lookup must be True or False."
# When the deprecation ends, replace with:
# raise ValueError(
# 'The QuerySet value for an isnull lookup must be True or '
# 'False.'
# )
warnings.warn(
'Using a non-boolean value for an isnull lookup is '
'deprecated, use True or False instead.',
RemovedInDjango40Warning,
)
sql, params = compiler.compile(self.lhs)
if self.rhs:
@@ -585,7 +527,7 @@ class IsNull(BuiltinLookup):
@Field.register_lookup
class Regex(BuiltinLookup):
lookup_name = "regex"
lookup_name = 'regex'
prepare_rhs = False
def as_sql(self, compiler, connection):
@@ -600,25 +542,16 @@ class Regex(BuiltinLookup):
@Field.register_lookup
class IRegex(Regex):
lookup_name = "iregex"
lookup_name = 'iregex'
class YearLookup(Lookup):
def year_lookup_bounds(self, connection, year):
from django.db.models.functions import ExtractIsoYear
iso_year = isinstance(self.lhs, ExtractIsoYear)
output_field = self.lhs.lhs.output_field
if isinstance(output_field, DateTimeField):
bounds = connection.ops.year_lookup_bounds_for_datetime_field(
year,
iso_year=iso_year,
)
bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
else:
bounds = connection.ops.year_lookup_bounds_for_date_field(
year,
iso_year=iso_year,
)
bounds = connection.ops.year_lookup_bounds_for_date_field(year)
return bounds
def as_sql(self, compiler, connection):
@@ -632,7 +565,7 @@ class YearLookup(Lookup):
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
start, finish = self.year_lookup_bounds(connection, self.rhs)
params.extend(self.get_bound_params(start, finish))
return "%s %s" % (lhs_sql, rhs_sql), params
return '%s %s' % (lhs_sql, rhs_sql), params
return super().as_sql(compiler, connection)
def get_direct_rhs_sql(self, connection, rhs):
@@ -640,13 +573,13 @@ class YearLookup(Lookup):
def get_bound_params(self, start, finish):
raise NotImplementedError(
"subclasses of YearLookup must provide a get_bound_params() method"
'subclasses of YearLookup must provide a get_bound_params() method'
)
class YearExact(YearLookup, Exact):
def get_direct_rhs_sql(self, connection, rhs):
return "BETWEEN %s AND %s"
return 'BETWEEN %s AND %s'
def get_bound_params(self, start, finish):
return (start, finish)
@@ -677,16 +610,12 @@ class UUIDTextMixin:
Strip hyphens from a value when filtering a UUIDField on backends without
a native datatype for UUID.
"""
def process_rhs(self, qn, connection):
if not connection.features.has_native_uuid_field:
from django.db.models.functions import Replace
if self.rhs_is_direct_value():
self.rhs = Value(self.rhs)
self.rhs = Replace(
self.rhs, Value("-"), Value(""), output_field=CharField()
)
self.rhs = Replace(self.rhs, Value('-'), Value(''), output_field=CharField())
rhs, params = super().process_rhs(qn, connection)
return rhs, params
@@ -33,7 +33,7 @@ class BaseManager:
def __str__(self):
"""Return "app_label.model_label.manager_name"."""
return "%s.%s" % (self.model._meta.label, self.name)
return '%s.%s' % (self.model._meta.label, self.name)
def __class_getitem__(cls, *args, **kwargs):
return cls
@@ -46,12 +46,12 @@ class BaseManager:
Raise a ValueError if the manager is dynamically generated.
"""
qs_class = self._queryset_class
if getattr(self, "_built_with_as_manager", False):
if getattr(self, '_built_with_as_manager', False):
# using MyQuerySet.as_manager()
return (
True, # as_manager
None, # manager_class
"%s.%s" % (qs_class.__module__, qs_class.__name__), # qs_class
'%s.%s' % (qs_class.__module__, qs_class.__name__), # qs_class
None, # args
None, # kwargs
)
@@ -69,7 +69,7 @@ class BaseManager:
)
return (
False, # as_manager
"%s.%s" % (module_name, name), # manager_class
'%s.%s' % (module_name, name), # manager_class
None, # qs_class
self._constructor_args[0], # args
self._constructor_args[1], # kwargs
@@ -83,22 +83,18 @@ class BaseManager:
def create_method(name, method):
def manager_method(self, *args, **kwargs):
return getattr(self.get_queryset(), name)(*args, **kwargs)
manager_method.__name__ = method.__name__
manager_method.__doc__ = method.__doc__
return manager_method
new_methods = {}
for name, method in inspect.getmembers(
queryset_class, predicate=inspect.isfunction
):
for name, method in inspect.getmembers(queryset_class, predicate=inspect.isfunction):
# Only copy missing methods.
if hasattr(cls, name):
continue
# Only copy public methods or methods with the attribute
# queryset_only=False.
queryset_only = getattr(method, "queryset_only", None)
if queryset_only or (queryset_only is None and name.startswith("_")):
# Only copy public methods or methods with the attribute `queryset_only=False`.
queryset_only = getattr(method, 'queryset_only', None)
if queryset_only or (queryset_only is None and name.startswith('_')):
continue
# Copy the method onto the manager.
new_methods[name] = create_method(name, method)
@@ -107,15 +103,11 @@ class BaseManager:
@classmethod
def from_queryset(cls, queryset_class, class_name=None):
if class_name is None:
class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__)
return type(
class_name,
(cls,),
{
"_queryset_class": queryset_class,
**cls._get_queryset_methods(queryset_class),
},
)
class_name = '%sFrom%s' % (cls.__name__, queryset_class.__name__)
return type(class_name, (cls,), {
'_queryset_class': queryset_class,
**cls._get_queryset_methods(queryset_class),
})
def contribute_to_class(self, cls, name):
self.name = self.name or name
@@ -165,8 +157,8 @@ class BaseManager:
def __eq__(self, other):
return (
isinstance(other, self.__class__)
and self._constructor_args == other._constructor_args
isinstance(other, self.__class__) and
self._constructor_args == other._constructor_args
)
def __hash__(self):
@@ -178,24 +170,22 @@ class Manager(BaseManager.from_queryset(QuerySet)):
class ManagerDescriptor:
def __init__(self, manager):
self.manager = manager
def __get__(self, instance, cls=None):
if instance is not None:
raise AttributeError(
"Manager isn't accessible via %s instances" % cls.__name__
)
raise AttributeError("Manager isn't accessible via %s instances" % cls.__name__)
if cls._meta.abstract:
raise AttributeError(
"Manager isn't available; %s is abstract" % (cls._meta.object_name,)
)
raise AttributeError("Manager isn't available; %s is abstract" % (
cls._meta.object_name,
))
if cls._meta.swapped:
raise AttributeError(
"Manager isn't available; '%s' has been swapped for '%s'"
% (
"Manager isn't available; '%s' has been swapped for '%s'" % (
cls._meta.label,
cls._meta.swapped,
)
@@ -20,37 +20,18 @@ PROXY_PARENTS = object()
EMPTY_RELATION_TREE = ()
IMMUTABLE_WARNING = (
"The return type of '%s' should never be mutated. If you want to manipulate this "
"list for your own use, make a copy first."
"The return type of '%s' should never be mutated. If you want to manipulate this list "
"for your own use, make a copy first."
)
DEFAULT_NAMES = (
"verbose_name",
"verbose_name_plural",
"db_table",
"ordering",
"unique_together",
"permissions",
"get_latest_by",
"order_with_respect_to",
"app_label",
"db_tablespace",
"abstract",
"managed",
"proxy",
"swappable",
"auto_created",
"index_together",
"apps",
"default_permissions",
"select_on_save",
"default_related_name",
"required_db_features",
"required_db_vendor",
"base_manager_name",
"default_manager_name",
"indexes",
"constraints",
'verbose_name', 'verbose_name_plural', 'db_table', 'ordering',
'unique_together', 'permissions', 'get_latest_by', 'order_with_respect_to',
'app_label', 'db_tablespace', 'abstract', 'managed', 'proxy', 'swappable',
'auto_created', 'index_together', 'apps', 'default_permissions',
'select_on_save', 'default_related_name', 'required_db_features',
'required_db_vendor', 'base_manager_name', 'default_manager_name',
'indexes', 'constraints',
)
@@ -82,17 +63,11 @@ def make_immutable_fields_list(name, data):
class Options:
FORWARD_PROPERTIES = {
"fields",
"many_to_many",
"concrete_fields",
"local_concrete_fields",
"_forward_fields_map",
"managers",
"managers_map",
"base_manager",
"default_manager",
'fields', 'many_to_many', 'concrete_fields', 'local_concrete_fields',
'_forward_fields_map', 'managers', 'managers_map', 'base_manager',
'default_manager',
}
REVERSE_PROPERTIES = {"related_objects", "fields_map", "_relation_tree"}
REVERSE_PROPERTIES = {'related_objects', 'fields_map', '_relation_tree'}
default_apps = apps
@@ -107,7 +82,7 @@ class Options:
self.model_name = None
self.verbose_name = None
self.verbose_name_plural = None
self.db_table = ""
self.db_table = ''
self.ordering = []
self._ordering_clash = False
self.indexes = []
@@ -115,7 +90,7 @@ class Options:
self.unique_together = []
self.index_together = []
self.select_on_save = False
self.default_permissions = ("add", "change", "delete", "view")
self.default_permissions = ('add', 'change', 'delete', 'view')
self.permissions = []
self.object_name = None
self.app_label = app_label
@@ -155,11 +130,11 @@ class Options:
@property
def label(self):
return "%s.%s" % (self.app_label, self.object_name)
return '%s.%s' % (self.app_label, self.object_name)
@property
def label_lower(self):
return "%s.%s" % (self.app_label, self.model_name)
return '%s.%s' % (self.app_label, self.model_name)
@property
def app_config(self):
@@ -192,7 +167,7 @@ class Options:
# Ignore any private attributes that Django doesn't care about.
# NOTE: We can't modify a dictionary's contents while looping
# over it, so we loop over the *original* dictionary instead.
if name.startswith("_"):
if name.startswith('_'):
del meta_attrs[name]
for attr_name in DEFAULT_NAMES:
if attr_name in meta_attrs:
@@ -206,34 +181,30 @@ class Options:
self.index_together = normalize_together(self.index_together)
# App label/class name interpolation for names of constraints and
# indexes.
if not getattr(cls._meta, "abstract", False):
for attr_name in {"constraints", "indexes"}:
if not getattr(cls._meta, 'abstract', False):
for attr_name in {'constraints', 'indexes'}:
objs = getattr(self, attr_name, [])
setattr(self, attr_name, self._format_names_with_class(cls, objs))
# verbose_name_plural is a special case because it uses a 's'
# by default.
if self.verbose_name_plural is None:
self.verbose_name_plural = format_lazy("{}s", self.verbose_name)
self.verbose_name_plural = format_lazy('{}s', self.verbose_name)
# order_with_respect_and ordering are mutually exclusive.
self._ordering_clash = bool(self.ordering and self.order_with_respect_to)
# Any leftover attributes must be invalid.
if meta_attrs != {}:
raise TypeError(
"'class Meta' got invalid attribute(s): %s" % ",".join(meta_attrs)
)
raise TypeError("'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs))
else:
self.verbose_name_plural = format_lazy("{}s", self.verbose_name)
self.verbose_name_plural = format_lazy('{}s', self.verbose_name)
del self.meta
# If the db_table wasn't provided, use the app_label + model_name.
if not self.db_table:
self.db_table = "%s_%s" % (self.app_label, self.model_name)
self.db_table = truncate_name(
self.db_table, connection.ops.max_name_length()
)
self.db_table = truncate_name(self.db_table, connection.ops.max_name_length())
def _format_names_with_class(self, cls, objs):
"""App label/class name interpolation for object names."""
@@ -241,8 +212,8 @@ class Options:
for obj in objs:
obj = obj.clone()
obj.name = obj.name % {
"app_label": cls._meta.app_label.lower(),
"class": cls.__name__.lower(),
'app_label': cls._meta.app_label.lower(),
'class': cls.__name__.lower(),
}
new_objs.append(obj)
return new_objs
@@ -250,19 +221,19 @@ class Options:
def _get_default_pk_class(self):
pk_class_path = getattr(
self.app_config,
"default_auto_field",
'default_auto_field',
settings.DEFAULT_AUTO_FIELD,
)
if self.app_config and self.app_config._is_default_auto_field_overridden:
app_config_class = type(self.app_config)
source = (
f"{app_config_class.__module__}."
f"{app_config_class.__qualname__}.default_auto_field"
f'{app_config_class.__module__}.'
f'{app_config_class.__qualname__}.default_auto_field'
)
else:
source = "DEFAULT_AUTO_FIELD"
source = 'DEFAULT_AUTO_FIELD'
if not pk_class_path:
raise ImproperlyConfigured(f"{source} must not be empty.")
raise ImproperlyConfigured(f'{source} must not be empty.')
try:
pk_class = import_string(pk_class_path)
except ImportError as e:
@@ -285,20 +256,15 @@ class Options:
query = self.order_with_respect_to
try:
self.order_with_respect_to = next(
f
for f in self._get_fields(reverse=False)
f for f in self._get_fields(reverse=False)
if f.name == query or f.attname == query
)
except StopIteration:
raise FieldDoesNotExist(
"%s has no field named '%s'" % (self.object_name, query)
)
raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, query))
self.ordering = ("_order",)
if not any(
isinstance(field, OrderWrt) for field in model._meta.local_fields
):
model.add_to_class("_order", OrderWrt())
self.ordering = ('_order',)
if not any(isinstance(field, OrderWrt) for field in model._meta.local_fields):
model.add_to_class('_order', OrderWrt())
else:
self.order_with_respect_to = None
@@ -310,17 +276,15 @@ class Options:
# Look for a local field with the same name as the
# first parent link. If a local field has already been
# created, use it instead of promoting the parent
already_created = [
fld for fld in self.local_fields if fld.name == field.name
]
already_created = [fld for fld in self.local_fields if fld.name == field.name]
if already_created:
field = already_created[0]
field.primary_key = True
self.setup_pk(field)
else:
pk_class = self._get_default_pk_class()
auto = pk_class(verbose_name="ID", primary_key=True, auto_created=True)
model.add_to_class("id", auto)
auto = pk_class(verbose_name='ID', primary_key=True, auto_created=True)
model.add_to_class('id', auto)
def add_manager(self, manager):
self.local_managers.append(manager)
@@ -347,11 +311,7 @@ class Options:
# ideally, we'd just ask for field.related_model. However, related_model
# is a cached property, and all the models haven't been loaded yet, so
# we need to make sure we don't cache a string reference.
if (
field.is_relation
and hasattr(field.remote_field, "model")
and field.remote_field.model
):
if field.is_relation and hasattr(field.remote_field, 'model') and field.remote_field.model:
try:
field.remote_field.model._meta._expire_cache(forward=False)
except AttributeError:
@@ -375,7 +335,7 @@ class Options:
self.db_table = target._meta.db_table
def __repr__(self):
return "<Options for %s>" % self.object_name
return '<Options for %s>' % self.object_name
def __str__(self):
return self.label_lower
@@ -392,10 +352,8 @@ class Options:
if self.required_db_vendor:
return self.required_db_vendor == connection.vendor
if self.required_db_features:
return all(
getattr(connection.features, feat, False)
for feat in self.required_db_features
)
return all(getattr(connection.features, feat, False)
for feat in self.required_db_features)
return True
@property
@@ -417,7 +375,7 @@ class Options:
swapped_for = getattr(settings, self.swappable, None)
if swapped_for:
try:
swapped_label, swapped_object = swapped_for.split(".")
swapped_label, swapped_object = swapped_for.split('.')
except ValueError:
# setting not in the format app_label.model_name
# raising ImproperlyConfigured here causes problems with
@@ -425,10 +383,7 @@ class Options:
# or as part of validation.
return swapped_for
if (
"%s.%s" % (swapped_label, swapped_object.lower())
!= self.label_lower
):
if '%s.%s' % (swapped_label, swapped_object.lower()) != self.label_lower:
return swapped_for
return None
@@ -436,7 +391,7 @@ class Options:
def managers(self):
managers = []
seen_managers = set()
bases = (b for b in self.model.mro() if hasattr(b, "_meta"))
bases = (b for b in self.model.mro() if hasattr(b, '_meta'))
for depth, base in enumerate(bases):
for manager in base._meta.local_managers:
if manager.name in seen_managers:
@@ -462,8 +417,8 @@ class Options:
if not base_manager_name:
# Get the first parent's base_manager_name if there's one.
for parent in self.model.mro()[1:]:
if hasattr(parent, "_meta"):
if parent._base_manager.name != "_base_manager":
if hasattr(parent, '_meta'):
if parent._base_manager.name != '_base_manager':
base_manager_name = parent._base_manager.name
break
@@ -472,15 +427,14 @@ class Options:
return self.managers_map[base_manager_name]
except KeyError:
raise ValueError(
"%s has no manager named %r"
% (
"%s has no manager named %r" % (
self.object_name,
base_manager_name,
)
)
manager = Manager()
manager.name = "_base_manager"
manager.name = '_base_manager'
manager.model = self.model
manager.auto_created = True
return manager
@@ -491,7 +445,7 @@ class Options:
if not default_manager_name and not self.local_managers:
# Get the first parent's default_manager_name if there's one.
for parent in self.model.mro()[1:]:
if hasattr(parent, "_meta"):
if hasattr(parent, '_meta'):
default_manager_name = parent._meta.default_manager_name
break
@@ -500,8 +454,7 @@ class Options:
return self.managers_map[default_manager_name]
except KeyError:
raise ValueError(
"%s has no manager named %r"
% (
"%s has no manager named %r" % (
self.object_name,
default_manager_name,
)
@@ -535,20 +488,13 @@ class Options:
def is_not_a_generic_foreign_key(f):
return not (
f.is_relation
and f.many_to_one
and not (hasattr(f.remote_field, "model") and f.remote_field.model)
f.is_relation and f.many_to_one and not (hasattr(f.remote_field, 'model') and f.remote_field.model)
)
return make_immutable_fields_list(
"fields",
(
f
for f in self._get_fields(reverse=False)
if is_not_an_m2m_field(f)
and is_not_a_generic_relation(f)
and is_not_a_generic_foreign_key(f)
),
(f for f in self._get_fields(reverse=False)
if is_not_an_m2m_field(f) and is_not_a_generic_relation(f) and is_not_a_generic_foreign_key(f))
)
@cached_property
@@ -588,11 +534,7 @@ class Options:
"""
return make_immutable_fields_list(
"many_to_many",
(
f
for f in self._get_fields(reverse=False)
if f.is_relation and f.many_to_many
),
(f for f in self._get_fields(reverse=False) if f.is_relation and f.many_to_many)
)
@cached_property
@@ -606,16 +548,10 @@ class Options:
combined with filtering of field properties is the public API for
obtaining this field list.
"""
all_related_fields = self._get_fields(
forward=False, reverse=True, include_hidden=True
)
all_related_fields = self._get_fields(forward=False, reverse=True, include_hidden=True)
return make_immutable_fields_list(
"related_objects",
(
obj
for obj in all_related_fields
if not obj.hidden or obj.field.many_to_many
),
(obj for obj in all_related_fields if not obj.hidden or obj.field.many_to_many)
)
@cached_property
@@ -671,9 +607,7 @@ class Options:
# field map.
return self.fields_map[field_name]
except KeyError:
raise FieldDoesNotExist(
"%s has no field named '%s'" % (self.object_name, field_name)
)
raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, field_name))
def get_base_chain(self, model):
"""
@@ -742,17 +676,15 @@ class Options:
final_field = opts.parents[int_model]
targets = (final_field.remote_field.get_related_field(),)
opts = int_model._meta
path.append(
PathInfo(
from_opts=final_field.model._meta,
to_opts=opts,
target_fields=targets,
join_field=final_field,
m2m=False,
direct=True,
filtered_relation=None,
)
)
path.append(PathInfo(
from_opts=final_field.model._meta,
to_opts=opts,
target_fields=targets,
join_field=final_field,
m2m=False,
direct=True,
filtered_relation=None,
))
return path
def get_path_from_parent(self, parent):
@@ -794,8 +726,7 @@ class Options:
if opts.abstract:
continue
fields_with_relations = (
f
for f in opts._get_fields(reverse=False, include_parents=False)
f for f in opts._get_fields(reverse=False, include_parents=False)
if f.is_relation and f.related_model is not None
)
for f in fields_with_relations:
@@ -809,13 +740,11 @@ class Options:
# __dict__ takes precedence over a data descriptor (such as
# @cached_property). This means that the _meta._relation_tree is
# only called if related_objects is not in __dict__.
related_objects = related_objects_graph[
model._meta.concrete_model._meta.label
]
model._meta.__dict__["_relation_tree"] = related_objects
related_objects = related_objects_graph[model._meta.concrete_model._meta.label]
model._meta.__dict__['_relation_tree'] = related_objects
# It seems it is possible that self is not in all_models, so guard
# against that with default for get().
return self.__dict__.get("_relation_tree", EMPTY_RELATION_TREE)
return self.__dict__.get('_relation_tree', EMPTY_RELATION_TREE)
@cached_property
def _relation_tree(self):
@@ -846,18 +775,10 @@ class Options:
"""
if include_parents is False:
include_parents = PROXY_PARENTS
return self._get_fields(
include_parents=include_parents, include_hidden=include_hidden
)
return self._get_fields(include_parents=include_parents, include_hidden=include_hidden)
def _get_fields(
self,
forward=True,
reverse=True,
include_parents=True,
include_hidden=False,
seen_models=None,
):
def _get_fields(self, forward=True, reverse=True, include_parents=True, include_hidden=False,
seen_models=None):
"""
Internal helper function to return fields of the model.
* If forward=True, then fields defined on this model are returned.
@@ -870,9 +791,7 @@ class Options:
parent chain to the model's concrete model.
"""
if include_parents not in (True, False, PROXY_PARENTS):
raise TypeError(
"Invalid argument for include_parents: %s" % (include_parents,)
)
raise TypeError("Invalid argument for include_parents: %s" % (include_parents,))
# This helper function is used to allow recursion in ``get_fields()``
# implementation and to provide a fast way for Django's internals to
# access specific subsets of fields.
@@ -904,22 +823,13 @@ class Options:
# fields from the same parent again.
if parent in seen_models:
continue
if (
parent._meta.concrete_model != self.concrete_model
and include_parents == PROXY_PARENTS
):
if (parent._meta.concrete_model != self.concrete_model and
include_parents == PROXY_PARENTS):
continue
for obj in parent._meta._get_fields(
forward=forward,
reverse=reverse,
include_parents=include_parents,
include_hidden=include_hidden,
seen_models=seen_models,
):
if (
not getattr(obj, "parent_link", False)
or obj.model == self.concrete_model
):
forward=forward, reverse=reverse, include_parents=include_parents,
include_hidden=include_hidden, seen_models=seen_models):
if not getattr(obj, 'parent_link', False) or obj.model == self.concrete_model:
fields.append(obj)
if reverse and not self.proxy:
# Tree is computed once and cached until the app cache is expired.
@@ -960,11 +870,7 @@ class Options:
return [
constraint
for constraint in self.constraints
if (
isinstance(constraint, UniqueConstraint)
and constraint.condition is None
and not constraint.contains_expressions
)
if isinstance(constraint, UniqueConstraint) and constraint.condition is None
]
@cached_property
@@ -984,9 +890,6 @@ class Options:
Fields to be returned after a database insert.
"""
return [
field
for field in self._get_fields(
forward=True, reverse=False, include_parents=PROXY_PARENTS
)
if getattr(field, "db_returning", False)
field for field in self._get_fields(forward=True, reverse=False, include_parents=PROXY_PARENTS)
if getattr(field, 'db_returning', False)
]
File diff suppressed because it is too large Load Diff
@@ -8,19 +8,44 @@ circular import difficulties.
import copy
import functools
import inspect
import warnings
from collections import namedtuple
from django.core.exceptions import FieldError
from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db.models.constants import LOOKUP_SEP
from django.utils import tree
from django.utils.deprecation import RemovedInDjango40Warning
# PathInfo is used when converting lookups (fk__somecol). The contents
# describe the relation in Model terms (model Options and Fields for both
# sides of the relation. The join_field is the field backing the relation.
PathInfo = namedtuple(
"PathInfo",
"from_opts to_opts target_fields join_field m2m direct filtered_relation",
)
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation')
class InvalidQueryType(type):
@property
def _subclasses(self):
return (FieldDoesNotExist, FieldError)
def __warn(self):
warnings.warn(
'The InvalidQuery exception class is deprecated. Use '
'FieldDoesNotExist or FieldError instead.',
category=RemovedInDjango40Warning,
stacklevel=4,
)
def __instancecheck__(self, instance):
self.__warn()
return isinstance(instance, self._subclasses) or super().__instancecheck__(instance)
def __subclasscheck__(self, subclass):
self.__warn()
return issubclass(subclass, self._subclasses) or super().__subclasscheck__(subclass)
class InvalidQuery(Exception, metaclass=InvalidQueryType):
pass
def subclasses(cls):
@@ -34,26 +59,21 @@ class Q(tree.Node):
Encapsulate filters as objects that can then be combined logically (using
`&` and `|`).
"""
# Connection types
AND = "AND"
OR = "OR"
AND = 'AND'
OR = 'OR'
default = AND
conditional = True
def __init__(self, *args, _connector=None, _negated=False, **kwargs):
super().__init__(
children=[*args, *sorted(kwargs.items())],
connector=_connector,
negated=_negated,
)
super().__init__(children=[*args, *sorted(kwargs.items())], connector=_connector, negated=_negated)
def _combine(self, other, conn):
if not (isinstance(other, Q) or getattr(other, "conditional", False) is True):
if not(isinstance(other, Q) or getattr(other, 'conditional', False) is True):
raise TypeError(other)
if not self:
return other.copy() if hasattr(other, "copy") else copy.copy(other)
return other.copy() if hasattr(other, 'copy') else copy.copy(other)
elif isinstance(other, Q) and not other:
_, args, kwargs = self.deconstruct()
return type(self)(*args, **kwargs)
@@ -76,31 +96,26 @@ class Q(tree.Node):
obj.negate()
return obj
def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
# We must promote any new joins to left outer joins so that when Q is
# used as an expression, rows aren't filtered due to joins.
clause, joins = query._add_q(
self,
reuse,
allow_joins=allow_joins,
split_subq=False,
self, reuse, allow_joins=allow_joins, split_subq=False,
check_filterable=False,
)
query.promote_joins(joins)
return clause
def deconstruct(self):
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
if path.startswith("django.db.models.query_utils"):
path = path.replace("django.db.models.query_utils", "django.db.models")
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
if path.startswith('django.db.models.query_utils'):
path = path.replace('django.db.models.query_utils', 'django.db.models')
args = tuple(self.children)
kwargs = {}
if self.connector != self.default:
kwargs["_connector"] = self.connector
kwargs['_connector'] = self.connector
if self.negated:
kwargs["_negated"] = True
kwargs['_negated'] = True
return path, args, kwargs
@@ -109,7 +124,6 @@ class DeferredAttribute:
A wrapper for a deferred-loading field. When the value is read from this
object the first time, the query is executed.
"""
def __init__(self, field):
self.field = field
@@ -146,6 +160,7 @@ class DeferredAttribute:
class RegisterLookupMixin:
@classmethod
def _get_lookup(cls, lookup_name):
return cls.get_lookups().get(lookup_name, None)
@@ -153,16 +168,13 @@ class RegisterLookupMixin:
@classmethod
@functools.lru_cache(maxsize=None)
def get_lookups(cls):
class_lookups = [
parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
]
class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in inspect.getmro(cls)]
return cls.merge_dicts(class_lookups)
def get_lookup(self, lookup_name):
from django.db.models.lookups import Lookup
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, "output_field"):
if found is None and hasattr(self, 'output_field'):
return self.output_field.get_lookup(lookup_name)
if found is not None and not issubclass(found, Lookup):
return None
@@ -170,9 +182,8 @@ class RegisterLookupMixin:
def get_transform(self, lookup_name):
from django.db.models.lookups import Transform
found = self._get_lookup(lookup_name)
if found is None and hasattr(self, "output_field"):
if found is None and hasattr(self, 'output_field'):
return self.output_field.get_transform(lookup_name)
if found is not None and not issubclass(found, Transform):
return None
@@ -198,7 +209,7 @@ class RegisterLookupMixin:
def register_lookup(cls, lookup, lookup_name=None):
if lookup_name is None:
lookup_name = lookup.lookup_name
if "class_lookups" not in cls.__dict__:
if 'class_lookups' not in cls.__dict__:
cls.class_lookups = {}
cls.class_lookups[lookup_name] = lookup
cls._clear_cached_lookups()
@@ -245,8 +256,8 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
if field.attname not in load_fields:
if restricted and field.name in requested:
msg = (
"Field %s.%s cannot be both deferred and traversed using "
"select_related at the same time."
'Field %s.%s cannot be both deferred and traversed using '
'select_related at the same time.'
) % (field.model._meta.object_name, field.name)
raise FieldError(msg)
return True
@@ -272,14 +283,12 @@ def check_rel_lookup_compatibility(model, target_opts, field):
1) model and opts match (where proxy inheritance is removed)
2) model is parent of opts' model or the other way around
"""
def check(opts):
return (
model._meta.concrete_model == opts.concrete_model
or opts.concrete_model in model._meta.get_parent_list()
or model in opts.get_parent_list()
model._meta.concrete_model == opts.concrete_model or
opts.concrete_model in model._meta.get_parent_list() or
model in opts.get_parent_list()
)
# If the field is a primary key, then doing a query against the field's
# model is ok, too. Consider the case:
# class Restaurant(models.Model):
@@ -289,8 +298,9 @@ def check_rel_lookup_compatibility(model, target_opts, field):
# give Place's opts as the target opts, but Restaurant isn't compatible
# with that. This logic applies only to primary keys, as when doing __in=qs,
# we are going to turn this into __in=qs.values('pk') later on.
return check(target_opts) or (
getattr(field, "primary_key", False) and check(field.model._meta)
return (
check(target_opts) or
(getattr(field, 'primary_key', False) and check(field.model._meta))
)
@@ -299,11 +309,11 @@ class FilteredRelation:
def __init__(self, relation_name, *, condition=Q()):
if not relation_name:
raise ValueError("relation_name cannot be empty.")
raise ValueError('relation_name cannot be empty.')
self.relation_name = relation_name
self.alias = None
if not isinstance(condition, Q):
raise ValueError("condition argument must be a Q() instance.")
raise ValueError('condition argument must be a Q() instance.')
self.condition = condition
self.path = []
@@ -311,9 +321,9 @@ class FilteredRelation:
if not isinstance(other, self.__class__):
return NotImplemented
return (
self.relation_name == other.relation_name
and self.alias == other.alias
and self.condition == other.condition
self.relation_name == other.relation_name and
self.alias == other.alias and
self.condition == other.condition
)
def clone(self):
@@ -327,7 +337,7 @@ class FilteredRelation:
QuerySet.annotate() only accepts expression-like arguments
(with a resolve_expression() method).
"""
raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")
raise NotImplementedError('FilteredRelation.resolve_expression() is unused.')
def as_sql(self, compiler, connection):
# Resolve the condition in Join.filtered_relation.
@@ -11,7 +11,6 @@ class ModelSignal(Signal):
Signal subclass that allows the sender to be lazily specified as a string
of the `app_label.ModelName` form.
"""
def _lazy_method(self, method, apps, receiver, sender, **kwargs):
from django.db.models.options import Options
@@ -25,12 +24,8 @@ class ModelSignal(Signal):
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None):
self._lazy_method(
super().connect,
apps,
receiver,
sender,
weak=weak,
dispatch_uid=dispatch_uid,
super().connect, apps, receiver, sender,
weak=weak, dispatch_uid=dispatch_uid,
)
def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None):
@@ -3,4 +3,4 @@ from django.db.models.sql.query import Query
from django.db.models.sql.subqueries import * # NOQA
from django.db.models.sql.where import AND, OR
__all__ = ["Query", "AND", "OR"]
__all__ = ['Query', 'AND', 'OR']
File diff suppressed because it is too large Load Diff
@@ -1,6 +1,7 @@
"""
Constants specific to the SQL storage portion of the ORM.
"""
from django.utils.regex_helper import _lazy_re_compile
# Size of each "chunk" for get_iterator calls.
# Larger values are slightly faster at the expense of more storage space.
@@ -9,16 +10,17 @@ GET_ITERATOR_CHUNK_SIZE = 100
# Namedtuples for sql.* internal use.
# How many results to expect from a cursor.execute call
MULTI = "multi"
SINGLE = "single"
CURSOR = "cursor"
NO_RESULTS = "no results"
MULTI = 'multi'
SINGLE = 'single'
CURSOR = 'cursor'
NO_RESULTS = 'no results'
ORDER_DIR = {
"ASC": ("ASC", "DESC"),
"DESC": ("DESC", "ASC"),
'ASC': ('ASC', 'DESC'),
'DESC': ('DESC', 'ASC'),
}
ORDER_PATTERN = _lazy_re_compile(r'[-+]?[.\w]+$')
# SQL join types.
INNER = "INNER JOIN"
LOUTER = "LEFT OUTER JOIN"
INNER = 'INNER JOIN'
LOUTER = 'LEFT OUTER JOIN'
@@ -11,7 +11,6 @@ class MultiJoin(Exception):
multi-valued join was attempted (if the caller wants to treat that
exceptionally).
"""
def __init__(self, names_pos, path_with_names):
self.level = names_pos
# The path travelled, this includes the path to the multijoin.
@@ -26,8 +25,7 @@ class Join:
"""
Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
FROM entry. For example, the SQL generated could be
LEFT OUTER JOIN "sometable" T1
ON ("othertable"."sometable_id" = "sometable"."id")
LEFT OUTER JOIN "sometable" T1 ON ("othertable"."sometable_id" = "sometable"."id")
This class is primarily used in Query.alias_map. All entries in alias_map
must be Join compatible by providing the following attributes and methods:
@@ -40,17 +38,8 @@ class Join:
- as_sql()
- relabeled_clone()
"""
def __init__(
self,
table_name,
parent_alias,
table_alias,
join_type,
join_field,
nullable,
filtered_relation=None,
):
def __init__(self, table_name, parent_alias, table_alias, join_type,
join_field, nullable, filtered_relation=None):
# Join table
self.table_name = table_name
self.parent_alias = parent_alias
@@ -80,47 +69,36 @@ class Join:
# Add a join condition for each pair of joining columns.
for lhs_col, rhs_col in self.join_cols:
join_conditions.append(
"%s.%s = %s.%s"
% (
qn(self.parent_alias),
qn2(lhs_col),
qn(self.table_alias),
qn2(rhs_col),
)
)
join_conditions.append('%s.%s = %s.%s' % (
qn(self.parent_alias),
qn2(lhs_col),
qn(self.table_alias),
qn2(rhs_col),
))
# Add a single condition inside parentheses for whatever
# get_extra_restriction() returns.
extra_cond = self.join_field.get_extra_restriction(
self.table_alias, self.parent_alias
)
compiler.query.where_class, self.table_alias, self.parent_alias)
if extra_cond:
extra_sql, extra_params = compiler.compile(extra_cond)
join_conditions.append("(%s)" % extra_sql)
join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params)
if self.filtered_relation:
extra_sql, extra_params = compiler.compile(self.filtered_relation)
if extra_sql:
join_conditions.append("(%s)" % extra_sql)
join_conditions.append('(%s)' % extra_sql)
params.extend(extra_params)
if not join_conditions:
# This might be a rel on the other end of an actual declared field.
declared_field = getattr(self.join_field, "field", self.join_field)
declared_field = getattr(self.join_field, 'field', self.join_field)
raise ValueError(
"Join generated an empty ON clause. %s did not yield either "
"joining columns or extra restrictions." % declared_field.__class__
)
on_clause_sql = " AND ".join(join_conditions)
alias_str = (
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
)
sql = "%s %s%s ON (%s)" % (
self.join_type,
qn(self.table_name),
alias_str,
on_clause_sql,
)
on_clause_sql = ' AND '.join(join_conditions)
alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
sql = '%s %s%s ON (%s)' % (self.join_type, qn(self.table_name), alias_str, on_clause_sql)
return sql, params
def relabeled_clone(self, change_map):
@@ -128,19 +106,12 @@ class Join:
new_table_alias = change_map.get(self.table_alias, self.table_alias)
if self.filtered_relation is not None:
filtered_relation = self.filtered_relation.clone()
filtered_relation.path = [
change_map.get(p, p) for p in self.filtered_relation.path
]
filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
else:
filtered_relation = None
return self.__class__(
self.table_name,
new_parent_alias,
new_table_alias,
self.join_type,
self.join_field,
self.nullable,
filtered_relation=filtered_relation,
self.table_name, new_parent_alias, new_table_alias, self.join_type,
self.join_field, self.nullable, filtered_relation=filtered_relation,
)
@property
@@ -161,8 +132,9 @@ class Join:
def __hash__(self):
return hash(self.identity)
def equals(self, other):
# Ignore filtered_relation in equality check.
def equals(self, other, with_filtered_relation):
if with_filtered_relation:
return self == other
return self.identity[:-1] == other.identity[:-1]
def demote(self):
@@ -183,7 +155,6 @@ class BaseTable:
SELECT * FROM "foo" WHERE somecond
could be generated by this class.
"""
join_type = None
parent_alias = None
filtered_relation = None
@@ -193,16 +164,12 @@ class BaseTable:
self.table_alias = alias
def as_sql(self, compiler, connection):
alias_str = (
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
)
alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
base_sql = compiler.quote_name_unless_alias(self.table_name)
return base_sql + alias_str, []
def relabeled_clone(self, change_map):
return self.__class__(
self.table_name, change_map.get(self.table_alias, self.table_alias)
)
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
@property
def identity(self):
@@ -216,5 +183,5 @@ class BaseTable:
def __hash__(self):
return hash(self.identity)
def equals(self, other):
def equals(self, other, with_filtered_relation):
return self.identity == other.identity
File diff suppressed because it is too large Load Diff
@@ -3,16 +3,19 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
"""
from django.core.exceptions import FieldError
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.query_utils import Q
from django.db.models.sql.constants import (
CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS,
)
from django.db.models.sql.query import Query
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery']
class DeleteQuery(Query):
"""A DELETE SQL query."""
compiler = "SQLDeleteCompiler"
compiler = 'SQLDeleteCompiler'
def do_query(self, table, where, using):
self.alias_map = {table: self.alias_map[table]}
@@ -34,21 +37,17 @@ class DeleteQuery(Query):
num_deleted = 0
field = self.get_meta().pk
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
self.add_filter(
f"{field.attname}__in",
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
)
num_deleted += self.do_query(
self.get_meta().db_table, self.where, using=using
)
self.where = self.where_class()
self.add_q(Q(
**{field.attname + '__in': pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]}))
num_deleted += self.do_query(self.get_meta().db_table, self.where, using=using)
return num_deleted
class UpdateQuery(Query):
"""An UPDATE SQL query."""
compiler = "SQLUpdateCompiler"
compiler = 'SQLUpdateCompiler'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -71,10 +70,8 @@ class UpdateQuery(Query):
def update_batch(self, pk_list, values, using):
self.add_update_values(values)
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
self.clear_where()
self.add_filter(
"pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
)
self.where = self.where_class()
self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
self.get_compiler(using).execute_sql(NO_RESULTS)
def add_update_values(self, values):
@@ -86,14 +83,12 @@ class UpdateQuery(Query):
values_seq = []
for name, val in values.items():
field = self.get_meta().get_field(name)
direct = (
not (field.auto_created and not field.concrete) or not field.concrete
)
direct = not (field.auto_created and not field.concrete) or not field.concrete
model = field.model._meta.concrete_model
if not direct or (field.is_relation and field.many_to_many):
raise FieldError(
"Cannot update model field %r (only non-relations and "
"foreign keys permitted)." % field
'Cannot update model field %r (only non-relations and '
'foreign keys permitted).' % field
)
if model is not self.get_meta().concrete_model:
self.add_related_update(model, field, val)
@@ -108,7 +103,7 @@ class UpdateQuery(Query):
called add_update_targets() to hint at the extra information here.
"""
for field, model, val in values_seq:
if hasattr(val, "resolve_expression"):
if hasattr(val, 'resolve_expression'):
# Resolve expressions here so that annotations are no longer needed
val = val.resolve_expression(self, allow_joins=False, for_save=True)
self.values.append((field, model, val))
@@ -134,13 +129,13 @@ class UpdateQuery(Query):
query = UpdateQuery(model)
query.values = values
if self.related_ids is not None:
query.add_filter("pk__in", self.related_ids)
query.add_filter(('pk__in', self.related_ids))
result.append(query)
return result
class InsertQuery(Query):
compiler = "SQLInsertCompiler"
compiler = 'SQLInsertCompiler'
def __init__(self, *args, ignore_conflicts=False, **kwargs):
super().__init__(*args, **kwargs)
@@ -160,7 +155,7 @@ class AggregateQuery(Query):
elements in the provided list.
"""
compiler = "SQLAggregateCompiler"
compiler = 'SQLAggregateCompiler'
def __init__(self, model, inner_query):
self.inner_query = inner_query
@@ -7,8 +7,8 @@ from django.utils import tree
from django.utils.functional import cached_property
# Connection types
AND = "AND"
OR = "OR"
AND = 'AND'
OR = 'OR'
class WhereNode(tree.Node):
@@ -25,7 +25,6 @@ class WhereNode(tree.Node):
relabeled_clone() method or relabel_aliases() and clone() methods and
contains_aggregate attribute.
"""
default = AND
resolved = False
conditional = True
@@ -41,15 +40,15 @@ class WhereNode(tree.Node):
in_negated = negated ^ self.negated
# If the effective connector is OR and this node contains an aggregate,
# then we need to push the whole branch to HAVING clause.
may_need_split = (in_negated and self.connector == AND) or (
not in_negated and self.connector == OR
)
may_need_split = (
(in_negated and self.connector == AND) or
(not in_negated and self.connector == OR))
if may_need_split and self.contains_aggregate:
return None, self
where_parts = []
having_parts = []
for c in self.children:
if hasattr(c, "split_having"):
if hasattr(c, 'split_having'):
where_part, having_part = c.split_having(in_negated)
if where_part is not None:
where_parts.append(where_part)
@@ -59,16 +58,8 @@ class WhereNode(tree.Node):
having_parts.append(c)
else:
where_parts.append(c)
having_node = (
self.__class__(having_parts, self.connector, self.negated)
if having_parts
else None
)
where_node = (
self.__class__(where_parts, self.connector, self.negated)
if where_parts
else None
)
having_node = self.__class__(having_parts, self.connector, self.negated) if having_parts else None
where_node = self.__class__(where_parts, self.connector, self.negated) if where_parts else None
return where_node, having_node
def as_sql(self, compiler, connection):
@@ -103,24 +94,24 @@ class WhereNode(tree.Node):
# counts.
if empty_needed == 0:
if self.negated:
return "", []
return '', []
else:
raise EmptyResultSet
if full_needed == 0:
if self.negated:
raise EmptyResultSet
else:
return "", []
conn = " %s " % self.connector
return '', []
conn = ' %s ' % self.connector
sql_string = conn.join(result)
if sql_string:
if self.negated:
# Some backends (Oracle at least) need parentheses
# around the inner SQL in the negated case, even if the
# inner SQL contains just a single expression.
sql_string = "NOT (%s)" % sql_string
sql_string = 'NOT (%s)' % sql_string
elif len(result) > 1 or self.resolved:
sql_string = "(%s)" % sql_string
sql_string = '(%s)' % sql_string
return sql_string, result_params
def get_group_by_cols(self, alias=None):
@@ -142,10 +133,10 @@ class WhereNode(tree.Node):
mapping old (current) alias values to the new values.
"""
for pos, child in enumerate(self.children):
if hasattr(child, "relabel_aliases"):
if hasattr(child, 'relabel_aliases'):
# For example another WhereNode
child.relabel_aliases(change_map)
elif hasattr(child, "relabeled_clone"):
elif hasattr(child, 'relabeled_clone'):
self.children[pos] = child.relabeled_clone(change_map)
def clone(self):
@@ -155,12 +146,9 @@ class WhereNode(tree.Node):
value) tuples, or objects supporting .clone().
"""
clone = self.__class__._new_instance(
children=None,
connector=self.connector,
negated=self.negated,
)
children=[], connector=self.connector, negated=self.negated)
for child in self.children:
if hasattr(child, "clone"):
if hasattr(child, 'clone'):
clone.children.append(child.clone())
else:
clone.children.append(child)
@@ -194,20 +182,24 @@ class WhereNode(tree.Node):
def contains_over_clause(self):
return self._contains_over_clause(self)
@property
def is_summary(self):
return any(child.is_summary for child in self.children)
@staticmethod
def _resolve_leaf(expr, query, *args, **kwargs):
if hasattr(expr, "resolve_expression"):
if hasattr(expr, 'resolve_expression'):
expr = expr.resolve_expression(query, *args, **kwargs)
return expr
@classmethod
def _resolve_node(cls, node, query, *args, **kwargs):
if hasattr(node, "children"):
if hasattr(node, 'children'):
for child in node.children:
cls._resolve_node(child, query, *args, **kwargs)
if hasattr(node, "lhs"):
if hasattr(node, 'lhs'):
node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
if hasattr(node, "rhs"):
if hasattr(node, 'rhs'):
node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
def resolve_expression(self, *args, **kwargs):
@@ -216,30 +208,9 @@ class WhereNode(tree.Node):
clone.resolved = True
return clone
@cached_property
def output_field(self):
from django.db.models import BooleanField
return BooleanField()
def select_format(self, compiler, sql, params):
# Wrap filters with a CASE WHEN expression if a database backend
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
# BY list.
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
return sql, params
def get_db_converters(self, connection):
return self.output_field.get_db_converters(connection)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
class NothingNode:
"""A node that matches nothing."""
contains_aggregate = False
def as_sql(self, compiler=None, connection=None):
@@ -268,7 +239,6 @@ class SubqueryConstraint:
self.alias = alias
self.columns = columns
self.targets = targets
query_object.clear_ordering(clear_default=True)
self.query_object = query_object
def as_sql(self, compiler, connection):
@@ -46,7 +46,7 @@ def create_namedtuple_class(*names):
return unpickle_named_row, (names, tuple(self))
return type(
"Row",
(namedtuple("Row", names),),
{"__reduce__": __reduce__, "__slots__": ()},
'Row',
(namedtuple('Row', names),),
{'__reduce__': __reduce__, '__slots__': ()},
)