测试gitnore
This commit is contained in:
@@ -5,34 +5,14 @@ from django.db.models.aggregates import __all__ as aggregates_all
|
||||
from django.db.models.constraints import * # NOQA
|
||||
from django.db.models.constraints import __all__ as constraints_all
|
||||
from django.db.models.deletion import (
|
||||
CASCADE,
|
||||
DO_NOTHING,
|
||||
PROTECT,
|
||||
RESTRICT,
|
||||
SET,
|
||||
SET_DEFAULT,
|
||||
SET_NULL,
|
||||
ProtectedError,
|
||||
RestrictedError,
|
||||
CASCADE, DO_NOTHING, PROTECT, RESTRICT, SET, SET_DEFAULT, SET_NULL,
|
||||
ProtectedError, RestrictedError,
|
||||
)
|
||||
from django.db.models.enums import * # NOQA
|
||||
from django.db.models.enums import __all__ as enums_all
|
||||
from django.db.models.expressions import (
|
||||
Case,
|
||||
Exists,
|
||||
Expression,
|
||||
ExpressionList,
|
||||
ExpressionWrapper,
|
||||
F,
|
||||
Func,
|
||||
OrderBy,
|
||||
OuterRef,
|
||||
RowRange,
|
||||
Subquery,
|
||||
Value,
|
||||
ValueRange,
|
||||
When,
|
||||
Window,
|
||||
Case, Exists, Expression, ExpressionList, ExpressionWrapper, F, Func,
|
||||
OrderBy, OuterRef, RowRange, Subquery, Value, ValueRange, When, Window,
|
||||
WindowFrame,
|
||||
)
|
||||
from django.db.models.fields import * # NOQA
|
||||
@@ -50,66 +30,23 @@ from django.db.models.query_utils import FilteredRelation, Q
|
||||
# Imports that would create circular imports if sorted
|
||||
from django.db.models.base import DEFERRED, Model # isort:skip
|
||||
from django.db.models.fields.related import ( # isort:skip
|
||||
ForeignKey,
|
||||
ForeignObject,
|
||||
OneToOneField,
|
||||
ManyToManyField,
|
||||
ForeignObjectRel,
|
||||
ManyToOneRel,
|
||||
ManyToManyRel,
|
||||
OneToOneRel,
|
||||
ForeignKey, ForeignObject, OneToOneField, ManyToManyField,
|
||||
ForeignObjectRel, ManyToOneRel, ManyToManyRel, OneToOneRel,
|
||||
)
|
||||
|
||||
|
||||
__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
|
||||
__all__ += [
|
||||
"ObjectDoesNotExist",
|
||||
"signals",
|
||||
"CASCADE",
|
||||
"DO_NOTHING",
|
||||
"PROTECT",
|
||||
"RESTRICT",
|
||||
"SET",
|
||||
"SET_DEFAULT",
|
||||
"SET_NULL",
|
||||
"ProtectedError",
|
||||
"RestrictedError",
|
||||
"Case",
|
||||
"Exists",
|
||||
"Expression",
|
||||
"ExpressionList",
|
||||
"ExpressionWrapper",
|
||||
"F",
|
||||
"Func",
|
||||
"OrderBy",
|
||||
"OuterRef",
|
||||
"RowRange",
|
||||
"Subquery",
|
||||
"Value",
|
||||
"ValueRange",
|
||||
"When",
|
||||
"Window",
|
||||
"WindowFrame",
|
||||
"FileField",
|
||||
"ImageField",
|
||||
"JSONField",
|
||||
"OrderWrt",
|
||||
"Lookup",
|
||||
"Transform",
|
||||
"Manager",
|
||||
"Prefetch",
|
||||
"Q",
|
||||
"QuerySet",
|
||||
"prefetch_related_objects",
|
||||
"DEFERRED",
|
||||
"Model",
|
||||
"FilteredRelation",
|
||||
"ForeignKey",
|
||||
"ForeignObject",
|
||||
"OneToOneField",
|
||||
"ManyToManyField",
|
||||
"ForeignObjectRel",
|
||||
"ManyToOneRel",
|
||||
"ManyToManyRel",
|
||||
"OneToOneRel",
|
||||
'ObjectDoesNotExist', 'signals',
|
||||
'CASCADE', 'DO_NOTHING', 'PROTECT', 'RESTRICT', 'SET', 'SET_DEFAULT',
|
||||
'SET_NULL', 'ProtectedError', 'RestrictedError',
|
||||
'Case', 'Exists', 'Expression', 'ExpressionList', 'ExpressionWrapper', 'F',
|
||||
'Func', 'OrderBy', 'OuterRef', 'RowRange', 'Subquery', 'Value',
|
||||
'ValueRange', 'When',
|
||||
'Window', 'WindowFrame',
|
||||
'FileField', 'ImageField', 'JSONField', 'OrderWrt', 'Lookup', 'Transform',
|
||||
'Manager', 'Prefetch', 'Q', 'QuerySet', 'prefetch_related_objects',
|
||||
'DEFERRED', 'Model', 'FilteredRelation',
|
||||
'ForeignKey', 'ForeignObject', 'OneToOneField', 'ManyToManyField',
|
||||
'ForeignObjectRel', 'ManyToOneRel', 'ManyToManyRel', 'OneToOneRel',
|
||||
]
|
||||
|
||||
@@ -4,43 +4,28 @@ Classes to represent the definitions of aggregate functions.
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.expressions import Case, Func, Star, When
|
||||
from django.db.models.fields import IntegerField
|
||||
from django.db.models.functions.comparison import Coalesce
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDurationInputMixin,
|
||||
NumericOutputFieldMixin,
|
||||
FixDurationInputMixin, NumericOutputFieldMixin,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Aggregate",
|
||||
"Avg",
|
||||
"Count",
|
||||
"Max",
|
||||
"Min",
|
||||
"StdDev",
|
||||
"Sum",
|
||||
"Variance",
|
||||
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
|
||||
]
|
||||
|
||||
|
||||
class Aggregate(Func):
|
||||
template = "%(function)s(%(distinct)s%(expressions)s)"
|
||||
template = '%(function)s(%(distinct)s%(expressions)s)'
|
||||
contains_aggregate = True
|
||||
name = None
|
||||
filter_template = "%s FILTER (WHERE %%(filter)s)"
|
||||
filter_template = '%s FILTER (WHERE %%(filter)s)'
|
||||
window_compatible = True
|
||||
allow_distinct = False
|
||||
empty_result_set_value = None
|
||||
|
||||
def __init__(
|
||||
self, *expressions, distinct=False, filter=None, default=None, **extra
|
||||
):
|
||||
def __init__(self, *expressions, distinct=False, filter=None, **extra):
|
||||
if distinct and not self.allow_distinct:
|
||||
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
|
||||
if default is not None and self.empty_result_set_value is not None:
|
||||
raise TypeError(f"{self.__class__.__name__} does not allow default.")
|
||||
self.distinct = distinct
|
||||
self.filter = filter
|
||||
self.default = default
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def get_source_fields(self):
|
||||
@@ -57,14 +42,10 @@ class Aggregate(Func):
|
||||
self.filter = self.filter and exprs.pop()
|
||||
return super().set_source_expressions(exprs)
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
# Aggregates are not allowed in UPDATE queries, so ignore for_save
|
||||
c = super().resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.filter = c.filter and c.filter.resolve_expression(
|
||||
query, allow_joins, reuse, summarize
|
||||
)
|
||||
c.filter = c.filter and c.filter.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
if not summarize:
|
||||
# Call Aggregate.get_source_expressions() to avoid
|
||||
# returning self.filter and including that in this loop.
|
||||
@@ -72,48 +53,29 @@ class Aggregate(Func):
|
||||
for index, expr in enumerate(expressions):
|
||||
if expr.contains_aggregate:
|
||||
before_resolved = self.get_source_expressions()[index]
|
||||
name = (
|
||||
before_resolved.name
|
||||
if hasattr(before_resolved, "name")
|
||||
else repr(before_resolved)
|
||||
)
|
||||
raise FieldError(
|
||||
"Cannot compute %s('%s'): '%s' is an aggregate"
|
||||
% (c.name, name, name)
|
||||
)
|
||||
if (default := c.default) is None:
|
||||
return c
|
||||
if hasattr(default, "resolve_expression"):
|
||||
default = default.resolve_expression(query, allow_joins, reuse, summarize)
|
||||
c.default = None # Reset the default argument before wrapping.
|
||||
coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
|
||||
coalesce.is_summary = c.is_summary
|
||||
return coalesce
|
||||
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
|
||||
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
|
||||
return c
|
||||
|
||||
@property
|
||||
def default_alias(self):
|
||||
expressions = self.get_source_expressions()
|
||||
if len(expressions) == 1 and hasattr(expressions[0], "name"):
|
||||
return "%s__%s" % (expressions[0].name, self.name.lower())
|
||||
if len(expressions) == 1 and hasattr(expressions[0], 'name'):
|
||||
return '%s__%s' % (expressions[0].name, self.name.lower())
|
||||
raise TypeError("Complex expressions require an alias")
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
return []
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
|
||||
extra_context['distinct'] = 'DISTINCT ' if self.distinct else ''
|
||||
if self.filter:
|
||||
if connection.features.supports_aggregate_filter_clause:
|
||||
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
|
||||
template = self.filter_template % extra_context.get(
|
||||
"template", self.template
|
||||
)
|
||||
template = self.filter_template % extra_context.get('template', self.template)
|
||||
sql, params = super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
filter=filter_sql,
|
||||
**extra_context,
|
||||
compiler, connection, template=template, filter=filter_sql,
|
||||
**extra_context
|
||||
)
|
||||
return sql, params + filter_params
|
||||
else:
|
||||
@@ -122,74 +84,74 @@ class Aggregate(Func):
|
||||
source_expressions = copy.get_source_expressions()
|
||||
condition = When(self.filter, then=source_expressions[0])
|
||||
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
|
||||
return super(Aggregate, copy).as_sql(
|
||||
compiler, connection, **extra_context
|
||||
)
|
||||
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def _get_repr_options(self):
|
||||
options = super()._get_repr_options()
|
||||
if self.distinct:
|
||||
options["distinct"] = self.distinct
|
||||
options['distinct'] = self.distinct
|
||||
if self.filter:
|
||||
options["filter"] = self.filter
|
||||
options['filter'] = self.filter
|
||||
return options
|
||||
|
||||
|
||||
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
|
||||
function = "AVG"
|
||||
name = "Avg"
|
||||
function = 'AVG'
|
||||
name = 'Avg'
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Count(Aggregate):
|
||||
function = "COUNT"
|
||||
name = "Count"
|
||||
function = 'COUNT'
|
||||
name = 'Count'
|
||||
output_field = IntegerField()
|
||||
allow_distinct = True
|
||||
empty_result_set_value = 0
|
||||
|
||||
def __init__(self, expression, filter=None, **extra):
|
||||
if expression == "*":
|
||||
if expression == '*':
|
||||
expression = Star()
|
||||
if isinstance(expression, Star) and filter is not None:
|
||||
raise ValueError("Star cannot be used with filter. Please specify a field.")
|
||||
raise ValueError('Star cannot be used with filter. Please specify a field.')
|
||||
super().__init__(expression, filter=filter, **extra)
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
return 0 if value is None else value
|
||||
|
||||
|
||||
class Max(Aggregate):
|
||||
function = "MAX"
|
||||
name = "Max"
|
||||
function = 'MAX'
|
||||
name = 'Max'
|
||||
|
||||
|
||||
class Min(Aggregate):
|
||||
function = "MIN"
|
||||
name = "Min"
|
||||
function = 'MIN'
|
||||
name = 'Min'
|
||||
|
||||
|
||||
class StdDev(NumericOutputFieldMixin, Aggregate):
|
||||
name = "StdDev"
|
||||
name = 'StdDev'
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
|
||||
self.function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}
|
||||
return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
|
||||
|
||||
|
||||
class Sum(FixDurationInputMixin, Aggregate):
|
||||
function = "SUM"
|
||||
name = "Sum"
|
||||
function = 'SUM'
|
||||
name = 'Sum'
|
||||
allow_distinct = True
|
||||
|
||||
|
||||
class Variance(NumericOutputFieldMixin, Aggregate):
|
||||
name = "Variance"
|
||||
name = 'Variance'
|
||||
|
||||
def __init__(self, expression, sample=False, **extra):
|
||||
self.function = "VAR_SAMP" if sample else "VAR_POP"
|
||||
self.function = 'VAR_SAMP' if sample else 'VAR_POP'
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
def _get_repr_options(self):
|
||||
return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}
|
||||
return {**super()._get_repr_options(), 'sample': self.function == 'VAR_SAMP'}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,4 +3,4 @@ Constants used across the ORM in general.
|
||||
"""
|
||||
|
||||
# Separator used to split filter strings apart.
|
||||
LOOKUP_SEP = "__"
|
||||
LOOKUP_SEP = '__'
|
||||
|
||||
@@ -1,34 +1,28 @@
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models.expressions import ExpressionList, F
|
||||
from django.db.models.indexes import IndexExpression
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
__all__ = ["CheckConstraint", "Deferrable", "UniqueConstraint"]
|
||||
__all__ = ['CheckConstraint', 'Deferrable', 'UniqueConstraint']
|
||||
|
||||
|
||||
class BaseConstraint:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return False
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
raise NotImplementedError('This method must be implemented by a subclass.')
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
raise NotImplementedError('This method must be implemented by a subclass.')
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
raise NotImplementedError('This method must be implemented by a subclass.')
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace("django.db.models.constraints", "django.db.models")
|
||||
return (path, (), {"name": self.name})
|
||||
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace('django.db.models.constraints', 'django.db.models')
|
||||
return (path, (), {'name': self.name})
|
||||
|
||||
def clone(self):
|
||||
_, args, kwargs = self.deconstruct()
|
||||
@@ -38,9 +32,10 @@ class BaseConstraint:
|
||||
class CheckConstraint(BaseConstraint):
|
||||
def __init__(self, *, check, name):
|
||||
self.check = check
|
||||
if not getattr(check, "conditional", False):
|
||||
if not getattr(check, 'conditional', False):
|
||||
raise TypeError(
|
||||
"CheckConstraint.check must be a Q instance or boolean expression."
|
||||
'CheckConstraint.check must be a Q instance or boolean '
|
||||
'expression.'
|
||||
)
|
||||
super().__init__(name)
|
||||
|
||||
@@ -63,11 +58,7 @@ class CheckConstraint(BaseConstraint):
|
||||
return schema_editor._delete_check_sql(model, self.name)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: check=%s name=%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
self.check,
|
||||
repr(self.name),
|
||||
)
|
||||
return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CheckConstraint):
|
||||
@@ -76,84 +67,62 @@ class CheckConstraint(BaseConstraint):
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
kwargs["check"] = self.check
|
||||
kwargs['check'] = self.check
|
||||
return path, args, kwargs
|
||||
|
||||
|
||||
class Deferrable(Enum):
|
||||
DEFERRED = "deferred"
|
||||
IMMEDIATE = "immediate"
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__qualname__}.{self._name_}"
|
||||
DEFERRED = 'deferred'
|
||||
IMMEDIATE = 'immediate'
|
||||
|
||||
|
||||
class UniqueConstraint(BaseConstraint):
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
*,
|
||||
fields,
|
||||
name,
|
||||
condition=None,
|
||||
deferrable=None,
|
||||
include=None,
|
||||
opclasses=(),
|
||||
):
|
||||
if not name:
|
||||
raise ValueError("A unique constraint must be named.")
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
"At least one field or expression is required to define a "
|
||||
"unique constraint."
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and expressions are mutually exclusive."
|
||||
)
|
||||
if not fields:
|
||||
raise ValueError('At least one field is required to define a unique constraint.')
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError("UniqueConstraint.condition must be a Q instance.")
|
||||
raise ValueError('UniqueConstraint.condition must be a Q instance.')
|
||||
if condition and deferrable:
|
||||
raise ValueError("UniqueConstraint with conditions cannot be deferred.")
|
||||
if include and deferrable:
|
||||
raise ValueError("UniqueConstraint with include fields cannot be deferred.")
|
||||
if opclasses and deferrable:
|
||||
raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
|
||||
if expressions and deferrable:
|
||||
raise ValueError("UniqueConstraint with expressions cannot be deferred.")
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.opclasses cannot be used with expressions. "
|
||||
"Use django.contrib.postgres.indexes.OpClass() instead."
|
||||
'UniqueConstraint with conditions cannot be deferred.'
|
||||
)
|
||||
if include and deferrable:
|
||||
raise ValueError(
|
||||
'UniqueConstraint with include fields cannot be deferred.'
|
||||
)
|
||||
if opclasses and deferrable:
|
||||
raise ValueError(
|
||||
'UniqueConstraint with opclasses cannot be deferred.'
|
||||
)
|
||||
if not isinstance(deferrable, (type(None), Deferrable)):
|
||||
raise ValueError(
|
||||
"UniqueConstraint.deferrable must be a Deferrable instance."
|
||||
'UniqueConstraint.deferrable must be a Deferrable instance.'
|
||||
)
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError("UniqueConstraint.include must be a list or tuple.")
|
||||
raise ValueError('UniqueConstraint.include must be a list or tuple.')
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
|
||||
raise ValueError('UniqueConstraint.opclasses must be a list or tuple.')
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and UniqueConstraint.opclasses must "
|
||||
"have the same number of elements."
|
||||
'UniqueConstraint.fields and UniqueConstraint.opclasses must '
|
||||
'have the same number of elements.'
|
||||
)
|
||||
self.fields = tuple(fields)
|
||||
self.condition = condition
|
||||
self.deferrable = deferrable
|
||||
self.include = tuple(include) if include else ()
|
||||
self.opclasses = opclasses
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
super().__init__(name)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
@@ -163,105 +132,64 @@ class UniqueConstraint(BaseConstraint):
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def _get_index_expressions(self, model, schema_editor):
|
||||
if not self.expressions:
|
||||
return None
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
return ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
fields = [model._meta.get_field(field_name).column for field_name in self.fields]
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
model, fields, self.name, condition=condition,
|
||||
deferrable=self.deferrable, include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
fields = [model._meta.get_field(field_name).column for field_name in self.fields]
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._create_unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
model, fields, self.name, condition=condition,
|
||||
deferrable=self.deferrable, include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
return schema_editor._delete_unique_sql(
|
||||
model,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
model, self.name, condition=condition, deferrable=self.deferrable,
|
||||
include=include, opclasses=self.opclasses,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:%s%s%s%s%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
"" if not self.fields else " fields=%s" % repr(self.fields),
|
||||
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
|
||||
" name=%s" % repr(self.name),
|
||||
"" if self.condition is None else " condition=%s" % self.condition,
|
||||
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
|
||||
"" if not self.include else " include=%s" % repr(self.include),
|
||||
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
|
||||
return '<%s: fields=%r name=%r%s%s%s%s>' % (
|
||||
self.__class__.__name__, self.fields, self.name,
|
||||
'' if self.condition is None else ' condition=%s' % self.condition,
|
||||
'' if self.deferrable is None else ' deferrable=%s' % self.deferrable,
|
||||
'' if not self.include else ' include=%s' % repr(self.include),
|
||||
'' if not self.opclasses else ' opclasses=%s' % repr(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, UniqueConstraint):
|
||||
return (
|
||||
self.name == other.name
|
||||
and self.fields == other.fields
|
||||
and self.condition == other.condition
|
||||
and self.deferrable == other.deferrable
|
||||
and self.include == other.include
|
||||
and self.opclasses == other.opclasses
|
||||
and self.expressions == other.expressions
|
||||
self.name == other.name and
|
||||
self.fields == other.fields and
|
||||
self.condition == other.condition and
|
||||
self.deferrable == other.deferrable and
|
||||
self.include == other.include and
|
||||
self.opclasses == other.opclasses
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
if self.fields:
|
||||
kwargs["fields"] = self.fields
|
||||
kwargs['fields'] = self.fields
|
||||
if self.condition:
|
||||
kwargs["condition"] = self.condition
|
||||
kwargs['condition'] = self.condition
|
||||
if self.deferrable:
|
||||
kwargs["deferrable"] = self.deferrable
|
||||
kwargs['deferrable'] = self.deferrable
|
||||
if self.include:
|
||||
kwargs["include"] = self.include
|
||||
kwargs['include'] = self.include
|
||||
if self.opclasses:
|
||||
kwargs["opclasses"] = self.opclasses
|
||||
return path, self.expressions, kwargs
|
||||
kwargs['opclasses'] = self.opclasses
|
||||
return path, args, kwargs
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import operator
|
||||
from collections import Counter, defaultdict
|
||||
from functools import partial
|
||||
from functools import partial, reduce
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
|
||||
@@ -21,11 +22,8 @@ class RestrictedError(IntegrityError):
|
||||
|
||||
def CASCADE(collector, field, sub_objs, using):
|
||||
collector.collect(
|
||||
sub_objs,
|
||||
source=field.remote_field.model,
|
||||
source_attr=field.name,
|
||||
nullable=field.null,
|
||||
fail_on_restricted=False,
|
||||
sub_objs, source=field.remote_field.model, source_attr=field.name,
|
||||
nullable=field.null, fail_on_restricted=False,
|
||||
)
|
||||
if field.null and not connections[using].features.can_defer_constraint_checks:
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
@@ -34,13 +32,10 @@ def CASCADE(collector, field, sub_objs, using):
|
||||
def PROTECT(collector, field, sub_objs, using):
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model '%s' because they are "
|
||||
"referenced through a protected foreign key: '%s.%s'"
|
||||
% (
|
||||
field.remote_field.model.__name__,
|
||||
sub_objs[0].__class__.__name__,
|
||||
field.name,
|
||||
"referenced through a protected foreign key: '%s.%s'" % (
|
||||
field.remote_field.model.__name__, sub_objs[0].__class__.__name__, field.name
|
||||
),
|
||||
sub_objs,
|
||||
sub_objs
|
||||
)
|
||||
|
||||
|
||||
@@ -51,16 +46,12 @@ def RESTRICT(collector, field, sub_objs, using):
|
||||
|
||||
def SET(value):
|
||||
if callable(value):
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value(), sub_objs)
|
||||
|
||||
else:
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value, sub_objs)
|
||||
|
||||
set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
|
||||
set_on_delete.deconstruct = lambda: ('django.db.models.SET', (value,), {})
|
||||
return set_on_delete
|
||||
|
||||
|
||||
@@ -80,8 +71,7 @@ def get_candidate_relations_to_delete(opts):
|
||||
# The candidate relations are the ones that come from N-1 and 1-1 relations.
|
||||
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
|
||||
return (
|
||||
f
|
||||
for f in opts.get_fields(include_hidden=True)
|
||||
f for f in opts.get_fields(include_hidden=True)
|
||||
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
|
||||
)
|
||||
|
||||
@@ -133,9 +123,7 @@ class Collector:
|
||||
def add_dependency(self, model, dependency, reverse_dependency=False):
|
||||
if reverse_dependency:
|
||||
model, dependency = dependency, model
|
||||
self.dependencies[model._meta.concrete_model].add(
|
||||
dependency._meta.concrete_model
|
||||
)
|
||||
self.dependencies[model._meta.concrete_model].add(dependency._meta.concrete_model)
|
||||
self.data.setdefault(dependency, self.data.default_factory())
|
||||
|
||||
def add_field_update(self, field, value, objs):
|
||||
@@ -162,21 +150,17 @@ class Collector:
|
||||
|
||||
def clear_restricted_objects_from_queryset(self, model, qs):
|
||||
if model in self.restricted_objects:
|
||||
objs = set(
|
||||
qs.filter(
|
||||
pk__in=[
|
||||
obj.pk
|
||||
for objs in self.restricted_objects[model].values()
|
||||
for obj in objs
|
||||
]
|
||||
)
|
||||
)
|
||||
objs = set(qs.filter(pk__in=[
|
||||
obj.pk
|
||||
for objs in self.restricted_objects[model].values() for obj in objs
|
||||
]))
|
||||
self.clear_restricted_objects_from_set(model, objs)
|
||||
|
||||
def _has_signal_listeners(self, model):
|
||||
return signals.pre_delete.has_listeners(
|
||||
model
|
||||
) or signals.post_delete.has_listeners(model)
|
||||
return (
|
||||
signals.pre_delete.has_listeners(model) or
|
||||
signals.post_delete.has_listeners(model)
|
||||
)
|
||||
|
||||
def can_fast_delete(self, objs, from_field=None):
|
||||
"""
|
||||
@@ -191,9 +175,9 @@ class Collector:
|
||||
"""
|
||||
if from_field and from_field.remote_field.on_delete is not CASCADE:
|
||||
return False
|
||||
if hasattr(objs, "_meta"):
|
||||
if hasattr(objs, '_meta'):
|
||||
model = objs._meta.model
|
||||
elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
|
||||
elif hasattr(objs, 'model') and hasattr(objs, '_raw_delete'):
|
||||
model = objs.model
|
||||
else:
|
||||
return False
|
||||
@@ -203,22 +187,14 @@ class Collector:
|
||||
# parent when parent delete is cascading to child.
|
||||
opts = model._meta
|
||||
return (
|
||||
all(
|
||||
link == from_field
|
||||
for link in opts.concrete_model._meta.parents.values()
|
||||
)
|
||||
and
|
||||
all(link == from_field for link in opts.concrete_model._meta.parents.values()) and
|
||||
# Foreign keys pointing to this model.
|
||||
all(
|
||||
related.field.remote_field.on_delete is DO_NOTHING
|
||||
for related in get_candidate_relations_to_delete(opts)
|
||||
)
|
||||
and (
|
||||
) and (
|
||||
# Something like generic foreign key.
|
||||
not any(
|
||||
hasattr(field, "bulk_related_objects")
|
||||
for field in opts.private_fields
|
||||
)
|
||||
not any(hasattr(field, 'bulk_related_objects') for field in opts.private_fields)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -228,27 +204,16 @@ class Collector:
|
||||
"""
|
||||
field_names = [field.name for field in fields]
|
||||
conn_batch_size = max(
|
||||
connections[self.using].ops.bulk_batch_size(field_names, objs), 1
|
||||
)
|
||||
connections[self.using].ops.bulk_batch_size(field_names, objs), 1)
|
||||
if len(objs) > conn_batch_size:
|
||||
return [
|
||||
objs[i : i + conn_batch_size]
|
||||
for i in range(0, len(objs), conn_batch_size)
|
||||
]
|
||||
return [objs[i:i + conn_batch_size]
|
||||
for i in range(0, len(objs), conn_batch_size)]
|
||||
else:
|
||||
return [objs]
|
||||
|
||||
def collect(
|
||||
self,
|
||||
objs,
|
||||
source=None,
|
||||
nullable=False,
|
||||
collect_related=True,
|
||||
source_attr=None,
|
||||
reverse_dependency=False,
|
||||
keep_parents=False,
|
||||
fail_on_restricted=True,
|
||||
):
|
||||
def collect(self, objs, source=None, nullable=False, collect_related=True,
|
||||
source_attr=None, reverse_dependency=False, keep_parents=False,
|
||||
fail_on_restricted=True):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted as well as all
|
||||
parent instances. 'objs' must be a homogeneous iterable collection of
|
||||
@@ -275,9 +240,8 @@ class Collector:
|
||||
if self.can_fast_delete(objs):
|
||||
self.fast_deletes.append(objs)
|
||||
return
|
||||
new_objs = self.add(
|
||||
objs, source, nullable, reverse_dependency=reverse_dependency
|
||||
)
|
||||
new_objs = self.add(objs, source, nullable,
|
||||
reverse_dependency=reverse_dependency)
|
||||
if not new_objs:
|
||||
return
|
||||
|
||||
@@ -290,14 +254,11 @@ class Collector:
|
||||
for ptr in concrete_model._meta.parents.values():
|
||||
if ptr:
|
||||
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
|
||||
self.collect(
|
||||
parent_objs,
|
||||
source=model,
|
||||
source_attr=ptr.remote_field.related_name,
|
||||
collect_related=False,
|
||||
reverse_dependency=True,
|
||||
fail_on_restricted=False,
|
||||
)
|
||||
self.collect(parent_objs, source=model,
|
||||
source_attr=ptr.remote_field.related_name,
|
||||
collect_related=False,
|
||||
reverse_dependency=True,
|
||||
fail_on_restricted=False)
|
||||
if not collect_related:
|
||||
return
|
||||
|
||||
@@ -325,18 +286,11 @@ class Collector:
|
||||
# relationships are select_related as interactions between both
|
||||
# features are hard to get right. This should only happen in
|
||||
# the rare cases where .related_objects is overridden anyway.
|
||||
if not (
|
||||
sub_objs.query.select_related
|
||||
or self._has_signal_listeners(related_model)
|
||||
):
|
||||
referenced_fields = set(
|
||||
chain.from_iterable(
|
||||
(rf.attname for rf in rel.field.foreign_related_fields)
|
||||
for rel in get_candidate_relations_to_delete(
|
||||
related_model._meta
|
||||
)
|
||||
)
|
||||
)
|
||||
if not (sub_objs.query.select_related or self._has_signal_listeners(related_model)):
|
||||
referenced_fields = set(chain.from_iterable(
|
||||
(rf.attname for rf in rel.field.foreign_related_fields)
|
||||
for rel in get_candidate_relations_to_delete(related_model._meta)
|
||||
))
|
||||
sub_objs = sub_objs.only(*tuple(referenced_fields))
|
||||
if sub_objs:
|
||||
try:
|
||||
@@ -346,11 +300,10 @@ class Collector:
|
||||
protected_objects[key] += error.protected_objects
|
||||
if protected_objects:
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model %r because they are "
|
||||
"referenced through protected foreign keys: %s."
|
||||
% (
|
||||
'Cannot delete some instances of model %r because they are '
|
||||
'referenced through protected foreign keys: %s.' % (
|
||||
model.__name__,
|
||||
", ".join(protected_objects),
|
||||
', '.join(protected_objects),
|
||||
),
|
||||
set(chain.from_iterable(protected_objects.values())),
|
||||
)
|
||||
@@ -360,12 +313,10 @@ class Collector:
|
||||
sub_objs = self.related_objects(related_model, related_fields, batch)
|
||||
self.fast_deletes.append(sub_objs)
|
||||
for field in model._meta.private_fields:
|
||||
if hasattr(field, "bulk_related_objects"):
|
||||
if hasattr(field, 'bulk_related_objects'):
|
||||
# It's something like generic foreign key.
|
||||
sub_objs = field.bulk_related_objects(new_objs, self.using)
|
||||
self.collect(
|
||||
sub_objs, source=model, nullable=True, fail_on_restricted=False
|
||||
)
|
||||
self.collect(sub_objs, source=model, nullable=True, fail_on_restricted=False)
|
||||
|
||||
if fail_on_restricted:
|
||||
# Raise an error if collected restricted objects (RESTRICT) aren't
|
||||
@@ -383,12 +334,11 @@ class Collector:
|
||||
restricted_objects[key] += objs
|
||||
if restricted_objects:
|
||||
raise RestrictedError(
|
||||
"Cannot delete some instances of model %r because "
|
||||
"they are referenced through restricted foreign keys: "
|
||||
"%s."
|
||||
% (
|
||||
'Cannot delete some instances of model %r because '
|
||||
'they are referenced through restricted foreign keys: '
|
||||
'%s.' % (
|
||||
model.__name__,
|
||||
", ".join(restricted_objects),
|
||||
', '.join(restricted_objects),
|
||||
),
|
||||
set(chain.from_iterable(restricted_objects.values())),
|
||||
)
|
||||
@@ -397,10 +347,10 @@ class Collector:
|
||||
"""
|
||||
Get a QuerySet of the related model to objs via related fields.
|
||||
"""
|
||||
predicate = query_utils.Q(
|
||||
*((f"{related_field.name}__in", objs) for related_field in related_fields),
|
||||
_connector=query_utils.Q.OR,
|
||||
)
|
||||
predicate = reduce(operator.or_, (
|
||||
query_utils.Q(**{'%s__in' % related_field.name: objs})
|
||||
for related_field in related_fields
|
||||
))
|
||||
return related_model._base_manager.using(self.using).filter(predicate)
|
||||
|
||||
def instances_with_model(self):
|
||||
@@ -443,9 +393,7 @@ class Collector:
|
||||
instance = list(instances)[0]
|
||||
if self.can_fast_delete(instance):
|
||||
with transaction.mark_for_rollback_on_error(self.using):
|
||||
count = sql.DeleteQuery(model).delete_batch(
|
||||
[instance.pk], self.using
|
||||
)
|
||||
count = sql.DeleteQuery(model).delete_batch([instance.pk], self.using)
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return count, {model._meta.label: count}
|
||||
|
||||
@@ -467,9 +415,8 @@ class Collector:
|
||||
for model, instances_for_fieldvalues in self.field_updates.items():
|
||||
for (field, value), instances in instances_for_fieldvalues.items():
|
||||
query = sql.UpdateQuery(model)
|
||||
query.update_batch(
|
||||
[obj.pk for obj in instances], {field.name: value}, self.using
|
||||
)
|
||||
query.update_batch([obj.pk for obj in instances],
|
||||
{field.name: value}, self.using)
|
||||
|
||||
# reverse instance collections
|
||||
for instances in self.data.values():
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import enum
|
||||
from types import DynamicClassAttribute
|
||||
|
||||
from django.utils.functional import Promise
|
||||
|
||||
__all__ = ["Choices", "IntegerChoices", "TextChoices"]
|
||||
__all__ = ['Choices', 'IntegerChoices', 'TextChoices']
|
||||
|
||||
|
||||
class ChoicesMeta(enum.EnumMeta):
|
||||
@@ -14,21 +13,25 @@ class ChoicesMeta(enum.EnumMeta):
|
||||
for key in classdict._member_names:
|
||||
value = classdict[key]
|
||||
if (
|
||||
isinstance(value, (list, tuple))
|
||||
and len(value) > 1
|
||||
and isinstance(value[-1], (Promise, str))
|
||||
isinstance(value, (list, tuple)) and
|
||||
len(value) > 1 and
|
||||
isinstance(value[-1], (Promise, str))
|
||||
):
|
||||
*value, label = value
|
||||
value = tuple(value)
|
||||
else:
|
||||
label = key.replace("_", " ").title()
|
||||
label = key.replace('_', ' ').title()
|
||||
labels.append(label)
|
||||
# Use dict.__setitem__() to suppress defenses against double
|
||||
# assignment in enum's classdict.
|
||||
dict.__setitem__(classdict, key, value)
|
||||
cls = super().__new__(metacls, classname, bases, classdict, **kwds)
|
||||
for member, label in zip(cls.__members__.values(), labels):
|
||||
member._label_ = label
|
||||
cls._value2label_map_ = dict(zip(cls._value2member_map_, labels))
|
||||
# Add a label property to instances of enum which uses the enum member
|
||||
# that is passed in as "self" as the value to use when looking up the
|
||||
# label in the choices.
|
||||
cls.label = property(lambda self: cls._value2label_map_.get(self.value))
|
||||
cls.do_not_call_in_templates = True
|
||||
return enum.unique(cls)
|
||||
|
||||
def __contains__(cls, member):
|
||||
@@ -39,12 +42,12 @@ class ChoicesMeta(enum.EnumMeta):
|
||||
|
||||
@property
|
||||
def names(cls):
|
||||
empty = ["__empty__"] if hasattr(cls, "__empty__") else []
|
||||
empty = ['__empty__'] if hasattr(cls, '__empty__') else []
|
||||
return empty + [member.name for member in cls]
|
||||
|
||||
@property
|
||||
def choices(cls):
|
||||
empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
|
||||
empty = [(None, cls.__empty__)] if hasattr(cls, '__empty__') else []
|
||||
return empty + [(member.value, member.label) for member in cls]
|
||||
|
||||
@property
|
||||
@@ -59,14 +62,6 @@ class ChoicesMeta(enum.EnumMeta):
|
||||
class Choices(enum.Enum, metaclass=ChoicesMeta):
|
||||
"""Class for creating enumerated choices."""
|
||||
|
||||
@DynamicClassAttribute
|
||||
def label(self):
|
||||
return self._label_
|
||||
|
||||
@property
|
||||
def do_not_call_in_templates(self):
|
||||
return True
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Use value when cast to str, so that Choices set as model instance
|
||||
@@ -74,14 +69,9 @@ class Choices(enum.Enum, metaclass=ChoicesMeta):
|
||||
"""
|
||||
return str(self.value)
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__qualname__}.{self._name_}"
|
||||
|
||||
|
||||
class IntegerChoices(int, Choices):
|
||||
"""Class for creating enumerated integer choices."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@ class FieldFile(File):
|
||||
def __eq__(self, other):
|
||||
# Older code may be expecting FileField values to be simple strings.
|
||||
# By overriding the == operator, it can remain backwards compatibility.
|
||||
if hasattr(other, "name"):
|
||||
if hasattr(other, 'name'):
|
||||
return self.name == other.name
|
||||
return self.name == other
|
||||
|
||||
@@ -37,14 +37,12 @@ class FieldFile(File):
|
||||
|
||||
def _require_file(self):
|
||||
if not self:
|
||||
raise ValueError(
|
||||
"The '%s' attribute has no file associated with it." % self.field.name
|
||||
)
|
||||
raise ValueError("The '%s' attribute has no file associated with it." % self.field.name)
|
||||
|
||||
def _get_file(self):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self._file = self.storage.open(self.name, "rb")
|
||||
if getattr(self, '_file', None) is None:
|
||||
self._file = self.storage.open(self.name, 'rb')
|
||||
return self._file
|
||||
|
||||
def _set_file(self, file):
|
||||
@@ -72,14 +70,13 @@ class FieldFile(File):
|
||||
return self.file.size
|
||||
return self.storage.size(self.name)
|
||||
|
||||
def open(self, mode="rb"):
|
||||
def open(self, mode='rb'):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
if getattr(self, '_file', None) is None:
|
||||
self.file = self.storage.open(self.name, mode)
|
||||
else:
|
||||
self.file.open(mode)
|
||||
return self
|
||||
|
||||
# open() doesn't alter the file's contents, but it does reset the pointer
|
||||
open.alters_data = True
|
||||
|
||||
@@ -96,7 +93,6 @@ class FieldFile(File):
|
||||
# Save the object because it has changed, unless save is False
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
save.alters_data = True
|
||||
|
||||
def delete(self, save=True):
|
||||
@@ -104,7 +100,7 @@ class FieldFile(File):
|
||||
return
|
||||
# Only close the file if it's already open, which we know by the
|
||||
# presence of self._file
|
||||
if hasattr(self, "_file"):
|
||||
if hasattr(self, '_file'):
|
||||
self.close()
|
||||
del self.file
|
||||
|
||||
@@ -116,16 +112,15 @@ class FieldFile(File):
|
||||
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
delete.alters_data = True
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
file = getattr(self, "_file", None)
|
||||
file = getattr(self, '_file', None)
|
||||
return file is None or file.closed
|
||||
|
||||
def close(self):
|
||||
file = getattr(self, "_file", None)
|
||||
file = getattr(self, '_file', None)
|
||||
if file is not None:
|
||||
file.close()
|
||||
|
||||
@@ -134,12 +129,12 @@ class FieldFile(File):
|
||||
# the file's name. Everything else will be restored later, by
|
||||
# FileDescriptor below.
|
||||
return {
|
||||
"name": self.name,
|
||||
"closed": False,
|
||||
"_committed": True,
|
||||
"_file": None,
|
||||
"instance": self.instance,
|
||||
"field": self.field,
|
||||
'name': self.name,
|
||||
'closed': False,
|
||||
'_committed': True,
|
||||
'_file': None,
|
||||
'instance': self.instance,
|
||||
'field': self.field,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
@@ -161,7 +156,6 @@ class FileDescriptor(DeferredAttribute):
|
||||
>>> with open('/path/to/hello.world') as f:
|
||||
... instance.file = File(f)
|
||||
"""
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
@@ -204,7 +198,7 @@ class FileDescriptor(DeferredAttribute):
|
||||
# Finally, because of the (some would say boneheaded) way pickle works,
|
||||
# the underlying FieldFile might not actually itself have an associated
|
||||
# file. So we need to reset the details of the FieldFile in those cases.
|
||||
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
|
||||
elif isinstance(file, FieldFile) and not hasattr(file, 'field'):
|
||||
file.instance = instance
|
||||
file.field = self.field
|
||||
file.storage = self.field.storage
|
||||
@@ -231,10 +225,8 @@ class FileField(Field):
|
||||
|
||||
description = _("File")
|
||||
|
||||
def __init__(
|
||||
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
|
||||
):
|
||||
self._primary_key_set_explicitly = "primary_key" in kwargs
|
||||
def __init__(self, verbose_name=None, name=None, upload_to='', storage=None, **kwargs):
|
||||
self._primary_key_set_explicitly = 'primary_key' in kwargs
|
||||
|
||||
self.storage = storage or default_storage
|
||||
if callable(self.storage):
|
||||
@@ -244,15 +236,11 @@ class FileField(Field):
|
||||
if not isinstance(self.storage, Storage):
|
||||
raise TypeError(
|
||||
"%s.storage must be a subclass/instance of %s.%s"
|
||||
% (
|
||||
self.__class__.__qualname__,
|
||||
Storage.__module__,
|
||||
Storage.__qualname__,
|
||||
)
|
||||
% (self.__class__.__qualname__, Storage.__module__, Storage.__qualname__)
|
||||
)
|
||||
self.upload_to = upload_to
|
||||
|
||||
kwargs.setdefault("max_length", 100)
|
||||
kwargs.setdefault('max_length', 100)
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
@@ -266,24 +254,23 @@ class FileField(Field):
|
||||
if self._primary_key_set_explicitly:
|
||||
return [
|
||||
checks.Error(
|
||||
"'primary_key' is not a valid argument for a %s."
|
||||
% self.__class__.__name__,
|
||||
"'primary_key' is not a valid argument for a %s." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E201",
|
||||
id='fields.E201',
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _check_upload_to(self):
|
||||
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
|
||||
if isinstance(self.upload_to, str) and self.upload_to.startswith('/'):
|
||||
return [
|
||||
checks.Error(
|
||||
"%s's 'upload_to' argument must be a relative path, not an "
|
||||
"absolute path." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E202",
|
||||
hint="Remove the leading slash.",
|
||||
id='fields.E202',
|
||||
hint='Remove the leading slash.',
|
||||
)
|
||||
]
|
||||
else:
|
||||
@@ -293,9 +280,9 @@ class FileField(Field):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if kwargs.get("max_length") == 100:
|
||||
del kwargs["max_length"]
|
||||
kwargs["upload_to"] = self.upload_to
|
||||
kwargs['upload_to'] = self.upload_to
|
||||
if self.storage is not default_storage:
|
||||
kwargs["storage"] = getattr(self, "_storage_callable", self.storage)
|
||||
kwargs['storage'] = getattr(self, '_storage_callable', self.storage)
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
@@ -303,8 +290,7 @@ class FileField(Field):
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
# Need to convert File objects provided via a form to string for
|
||||
# database insertion.
|
||||
# Need to convert File objects provided via a form to string for database insertion
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
@@ -343,16 +329,14 @@ class FileField(Field):
|
||||
if data is not None:
|
||||
# This value will be converted to str and stored in the
|
||||
# database, so leaving False as-is is not acceptable.
|
||||
setattr(instance, self.name, data or "")
|
||||
setattr(instance, self.name, data or '')
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.FileField,
|
||||
"max_length": self.max_length,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
return super().formfield(**{
|
||||
'form_class': forms.FileField,
|
||||
'max_length': self.max_length,
|
||||
**kwargs,
|
||||
})
|
||||
|
||||
|
||||
class ImageFileDescriptor(FileDescriptor):
|
||||
@@ -360,7 +344,6 @@ class ImageFileDescriptor(FileDescriptor):
|
||||
Just like the FileDescriptor, but for ImageFields. The only difference is
|
||||
assigning the width/height to the width_field/height_field, if appropriate.
|
||||
"""
|
||||
|
||||
def __set__(self, instance, value):
|
||||
previous_file = instance.__dict__.get(self.field.attname)
|
||||
super().__set__(instance, value)
|
||||
@@ -381,7 +364,7 @@ class ImageFileDescriptor(FileDescriptor):
|
||||
class ImageFieldFile(ImageFile, FieldFile):
|
||||
def delete(self, save=True):
|
||||
# Clear the image dimensions cache
|
||||
if hasattr(self, "_dimensions_cache"):
|
||||
if hasattr(self, '_dimensions_cache'):
|
||||
del self._dimensions_cache
|
||||
super().delete(save)
|
||||
|
||||
@@ -391,14 +374,7 @@ class ImageField(FileField):
|
||||
descriptor_class = ImageFileDescriptor
|
||||
description = _("Image")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
width_field=None,
|
||||
height_field=None,
|
||||
**kwargs,
|
||||
):
|
||||
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
|
||||
self.width_field, self.height_field = width_field, height_field
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
@@ -414,13 +390,11 @@ class ImageField(FileField):
|
||||
except ImportError:
|
||||
return [
|
||||
checks.Error(
|
||||
"Cannot use ImageField because Pillow is not installed.",
|
||||
hint=(
|
||||
"Get Pillow at https://pypi.org/project/Pillow/ "
|
||||
'or run command "python -m pip install Pillow".'
|
||||
),
|
||||
'Cannot use ImageField because Pillow is not installed.',
|
||||
hint=('Get Pillow at https://pypi.org/project/Pillow/ '
|
||||
'or run command "python -m pip install Pillow".'),
|
||||
obj=self,
|
||||
id="fields.E210",
|
||||
id='fields.E210',
|
||||
)
|
||||
]
|
||||
else:
|
||||
@@ -429,9 +403,9 @@ class ImageField(FileField):
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.width_field:
|
||||
kwargs["width_field"] = self.width_field
|
||||
kwargs['width_field'] = self.width_field
|
||||
if self.height_field:
|
||||
kwargs["height_field"] = self.height_field
|
||||
kwargs['height_field'] = self.height_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
@@ -471,9 +445,9 @@ class ImageField(FileField):
|
||||
if not file and not force:
|
||||
return
|
||||
|
||||
dimension_fields_filled = not (
|
||||
(self.width_field and not getattr(instance, self.width_field))
|
||||
or (self.height_field and not getattr(instance, self.height_field))
|
||||
dimension_fields_filled = not(
|
||||
(self.width_field and not getattr(instance, self.width_field)) or
|
||||
(self.height_field and not getattr(instance, self.height_field))
|
||||
)
|
||||
# When both dimension fields have values, we are most likely loading
|
||||
# data from the database or updating an image field that already had
|
||||
@@ -501,9 +475,7 @@ class ImageField(FileField):
|
||||
setattr(instance, self.height_field, height)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.ImageField,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
return super().formfield(**{
|
||||
'form_class': forms.ImageField,
|
||||
**kwargs,
|
||||
})
|
||||
|
||||
@@ -10,36 +10,32 @@ from django.utils.translation import gettext_lazy as _
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
__all__ = ['JSONField']
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("A JSON object")
|
||||
description = _('A JSON object')
|
||||
default_error_messages = {
|
||||
"invalid": _("Value must be valid JSON."),
|
||||
'invalid': _('Value must be valid JSON.'),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
_default_hint = ('dict', '{}')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
self, verbose_name=None, name=None, encoder=None, decoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
if encoder and not callable(encoder):
|
||||
raise ValueError("The encoder parameter must be a callable object.")
|
||||
raise ValueError('The encoder parameter must be a callable object.')
|
||||
if decoder and not callable(decoder):
|
||||
raise ValueError("The decoder parameter must be a callable object.")
|
||||
raise ValueError('The decoder parameter must be a callable object.')
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
databases = kwargs.get("databases") or []
|
||||
databases = kwargs.get('databases') or []
|
||||
errors.extend(self._check_supported(databases))
|
||||
return errors
|
||||
|
||||
@@ -50,19 +46,20 @@ class JSONField(CheckFieldDefaultMixin, Field):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
self.model._meta.required_db_vendor and
|
||||
self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
"supports_json_field" in self.model._meta.required_db_features
|
||||
or connection.features.supports_json_field
|
||||
'supports_json_field' in self.model._meta.required_db_features or
|
||||
connection.features.supports_json_field
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"%s does not support JSONFields." % connection.display_name,
|
||||
'%s does not support JSONFields.'
|
||||
% connection.display_name,
|
||||
obj=self.model,
|
||||
id="fields.E180",
|
||||
id='fields.E180',
|
||||
)
|
||||
)
|
||||
return errors
|
||||
@@ -70,9 +67,9 @@ class JSONField(CheckFieldDefaultMixin, Field):
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.encoder is not None:
|
||||
kwargs["encoder"] = self.encoder
|
||||
kwargs['encoder'] = self.encoder
|
||||
if self.decoder is not None:
|
||||
kwargs["decoder"] = self.decoder
|
||||
kwargs['decoder'] = self.decoder
|
||||
return name, path, args, kwargs
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
@@ -88,7 +85,7 @@ class JSONField(CheckFieldDefaultMixin, Field):
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return "JSONField"
|
||||
return 'JSONField'
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if value is None:
|
||||
@@ -107,66 +104,64 @@ class JSONField(CheckFieldDefaultMixin, Field):
|
||||
json.dumps(value, cls=self.encoder)
|
||||
except TypeError:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["invalid"],
|
||||
code="invalid",
|
||||
params={"value": value},
|
||||
self.error_messages['invalid'],
|
||||
code='invalid',
|
||||
params={'value': value},
|
||||
)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return self.value_from_object(obj)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.JSONField,
|
||||
"encoder": self.encoder,
|
||||
"decoder": self.decoder,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
return super().formfield(**{
|
||||
'form_class': forms.JSONField,
|
||||
'encoder': self.encoder,
|
||||
'decoder': self.decoder,
|
||||
**kwargs,
|
||||
})
|
||||
|
||||
|
||||
def compile_json_path(key_transforms, include_root=True):
|
||||
path = ["$"] if include_root else []
|
||||
path = ['$'] if include_root else []
|
||||
for key_transform in key_transforms:
|
||||
try:
|
||||
num = int(key_transform)
|
||||
except ValueError: # non-integer
|
||||
path.append(".")
|
||||
path.append('.')
|
||||
path.append(json.dumps(key_transform))
|
||||
else:
|
||||
path.append("[%s]" % num)
|
||||
return "".join(path)
|
||||
path.append('[%s]' % num)
|
||||
return ''.join(path)
|
||||
|
||||
|
||||
class DataContains(PostgresOperatorLookup):
|
||||
lookup_name = "contains"
|
||||
postgres_operator = "@>"
|
||||
lookup_name = 'contains'
|
||||
postgres_operator = '@>'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contains lookup is not supported on this database backend."
|
||||
'contains lookup is not supported on this database backend.'
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
|
||||
return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params
|
||||
|
||||
|
||||
class ContainedBy(PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
postgres_operator = "<@"
|
||||
lookup_name = 'contained_by'
|
||||
postgres_operator = '<@'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contained_by lookup is not supported on this database backend."
|
||||
'contained_by lookup is not supported on this database backend.'
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(rhs_params) + tuple(lhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
|
||||
return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params
|
||||
|
||||
|
||||
class HasKeyLookup(PostgresOperatorLookup):
|
||||
@@ -175,13 +170,11 @@ class HasKeyLookup(PostgresOperatorLookup):
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
# Process JSON path from the left-hand side.
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
|
||||
compiler, connection
|
||||
)
|
||||
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
|
||||
lhs_json_path = compile_json_path(lhs_key_transforms)
|
||||
else:
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
lhs_json_path = "$"
|
||||
lhs_json_path = '$'
|
||||
sql = template % lhs
|
||||
# Process JSON path from the right-hand side.
|
||||
rhs = self.rhs
|
||||
@@ -193,27 +186,20 @@ class HasKeyLookup(PostgresOperatorLookup):
|
||||
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
|
||||
else:
|
||||
rhs_key_transforms = [key]
|
||||
rhs_params.append(
|
||||
"%s%s"
|
||||
% (
|
||||
lhs_json_path,
|
||||
compile_json_path(rhs_key_transforms, include_root=False),
|
||||
)
|
||||
)
|
||||
rhs_params.append('%s%s' % (
|
||||
lhs_json_path,
|
||||
compile_json_path(rhs_key_transforms, include_root=False),
|
||||
))
|
||||
# Add condition for each key.
|
||||
if self.logical_operator:
|
||||
sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
|
||||
sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
|
||||
return sql, tuple(lhs_params) + tuple(rhs_params)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
|
||||
)
|
||||
return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = self.as_sql(
|
||||
compiler, connection, template="JSON_EXISTS(%s, '%%s')"
|
||||
)
|
||||
sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle.
|
||||
return sql % tuple(params), []
|
||||
@@ -227,83 +213,64 @@ class HasKeyLookup(PostgresOperatorLookup):
|
||||
return super().as_postgresql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
|
||||
)
|
||||
return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
|
||||
|
||||
|
||||
class HasKey(HasKeyLookup):
|
||||
lookup_name = "has_key"
|
||||
postgres_operator = "?"
|
||||
lookup_name = 'has_key'
|
||||
postgres_operator = '?'
|
||||
prepare_rhs = False
|
||||
|
||||
|
||||
class HasKeys(HasKeyLookup):
|
||||
lookup_name = "has_keys"
|
||||
postgres_operator = "?&"
|
||||
logical_operator = " AND "
|
||||
lookup_name = 'has_keys'
|
||||
postgres_operator = '?&'
|
||||
logical_operator = ' AND '
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return [str(item) for item in self.rhs]
|
||||
|
||||
|
||||
class HasAnyKeys(HasKeys):
|
||||
lookup_name = "has_any_keys"
|
||||
postgres_operator = "?|"
|
||||
logical_operator = " OR "
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
lookup_name = 'has_any_keys'
|
||||
postgres_operator = '?|'
|
||||
logical_operator = ' OR '
|
||||
|
||||
|
||||
class JSONExact(lookups.Exact):
|
||||
can_use_none_as_rhs = True
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == 'sqlite':
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs == '%s' and rhs_params == [None]:
|
||||
# Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
|
||||
lhs = "JSON_TYPE(%s, '$')" % lhs
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Treat None lookup values as null.
|
||||
if rhs == "%s" and rhs_params == [None]:
|
||||
rhs_params = ["null"]
|
||||
if connection.vendor == "mysql":
|
||||
if rhs == '%s' and rhs_params == [None]:
|
||||
rhs_params = ['null']
|
||||
if connection.vendor == 'mysql':
|
||||
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
|
||||
rhs = rhs % tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
JSONField.register_lookup(HasKeys)
|
||||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
postgres_operator = "->"
|
||||
postgres_nested_operator = "#>"
|
||||
postgres_operator = '->'
|
||||
postgres_nested_operator = '#>'
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -316,50 +283,44 @@ class KeyTransform(Transform):
|
||||
key_transforms.insert(0, previous.key_name)
|
||||
previous = previous.lhs
|
||||
lhs, params = compiler.compile(previous)
|
||||
if connection.vendor == "oracle":
|
||||
if connection.vendor == 'oracle':
|
||||
# Escape string-formatting.
|
||||
key_transforms = [key.replace("%", "%%") for key in key_transforms]
|
||||
key_transforms = [key.replace('%', '%%') for key in key_transforms]
|
||||
return lhs, params, key_transforms
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
|
||||
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return (
|
||||
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
|
||||
% ((lhs, json_path) * 2)
|
||||
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
|
||||
((lhs, json_path) * 2)
|
||||
), tuple(params) * 2
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
if len(key_transforms) > 1:
|
||||
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
|
||||
sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
|
||||
return sql, tuple(params) + (key_transforms,)
|
||||
try:
|
||||
lookup = int(self.key_name)
|
||||
except ValueError:
|
||||
lookup = self.key_name
|
||||
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||||
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
postgres_operator = "->>"
|
||||
postgres_nested_operator = "#>>"
|
||||
postgres_operator = '->>'
|
||||
postgres_nested_operator = '#>>'
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
@@ -369,21 +330,39 @@ class KeyTransformTextLookupMixin:
|
||||
key values to text and performing the lookup on the resulting
|
||||
representation.
|
||||
"""
|
||||
|
||||
def __init__(self, key_transform, *args, **kwargs):
|
||||
if not isinstance(key_transform, KeyTransform):
|
||||
raise TypeError(
|
||||
"Transform should be an instance of KeyTransform in order to "
|
||||
"use this lookup."
|
||||
'Transform should be an instance of KeyTransform in order to '
|
||||
'use this lookup.'
|
||||
)
|
||||
key_text_transform = KeyTextTransform(
|
||||
key_transform.key_name,
|
||||
*key_transform.source_expressions,
|
||||
key_transform.key_name, *key_transform.source_expressions,
|
||||
**key_transform.extra,
|
||||
)
|
||||
super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == 'mysql':
|
||||
return 'LOWER(%s)' % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == 'mysql':
|
||||
return 'LOWER(%s)' % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class KeyTransformIsNull(lookups.IsNull):
|
||||
# key__isnull=False is the same as has_key='key'
|
||||
def as_oracle(self, compiler, connection):
|
||||
@@ -395,12 +374,12 @@ class KeyTransformIsNull(lookups.IsNull):
|
||||
return sql, params
|
||||
# Column doesn't have a key or IS NULL.
|
||||
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
template = "JSON_TYPE(%s, %%s) IS NULL"
|
||||
template = 'JSON_TYPE(%s, %%s) IS NULL'
|
||||
if not self.rhs:
|
||||
template = "JSON_TYPE(%s, %%s) IS NOT NULL"
|
||||
template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
|
||||
return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
@@ -411,81 +390,75 @@ class KeyTransformIsNull(lookups.IsNull):
|
||||
class KeyTransformIn(lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler,
|
||||
connection,
|
||||
sql,
|
||||
param,
|
||||
compiler, connection, sql, param,
|
||||
)
|
||||
if (
|
||||
not hasattr(param, "as_sql")
|
||||
and not connection.features.has_native_json_field
|
||||
not hasattr(param, 'as_sql') and
|
||||
not connection.features.has_native_json_field
|
||||
):
|
||||
if connection.vendor == "oracle":
|
||||
if connection.vendor == 'oracle':
|
||||
value = json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql = sql % "JSON_QUERY"
|
||||
sql = sql % 'JSON_QUERY'
|
||||
else:
|
||||
sql = sql % "JSON_VALUE"
|
||||
elif connection.vendor == "mysql" or (
|
||||
connection.vendor == "sqlite"
|
||||
and params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = sql % 'JSON_VALUE'
|
||||
elif connection.vendor in {'sqlite', 'mysql'}:
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
|
||||
sql = "JSON_UNQUOTE(%s)" % sql
|
||||
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
|
||||
sql = 'JSON_UNQUOTE(%s)' % sql
|
||||
return sql, params
|
||||
|
||||
|
||||
class KeyTransformExact(JSONExact):
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == 'sqlite':
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs == '%s' and rhs_params == ['null']:
|
||||
lhs, *_ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
lhs = 'JSON_TYPE(%s, %%s)' % lhs
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
return super(lookups.Exact, self).process_rhs(compiler, connection)
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "oracle":
|
||||
if connection.vendor == 'oracle':
|
||||
func = []
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
for value in rhs_params:
|
||||
value = json.loads(value)
|
||||
if isinstance(value, (list, dict)):
|
||||
func.append(sql % "JSON_QUERY")
|
||||
func.append(sql % 'JSON_QUERY')
|
||||
else:
|
||||
func.append(sql % "JSON_VALUE")
|
||||
func.append(sql % 'JSON_VALUE')
|
||||
rhs = rhs % tuple(func)
|
||||
elif connection.vendor == "sqlite":
|
||||
func = []
|
||||
for value in rhs_params:
|
||||
if value in connection.ops.jsonfield_datatype_values:
|
||||
func.append("%s")
|
||||
else:
|
||||
func.append("JSON_EXTRACT(%s, '$')")
|
||||
elif connection.vendor == 'sqlite':
|
||||
func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params]
|
||||
rhs = rhs % tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs_params == ["null"]:
|
||||
if rhs_params == ['null']:
|
||||
# Field has key and it's NULL.
|
||||
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
|
||||
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
|
||||
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
|
||||
is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
|
||||
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
|
||||
return (
|
||||
"%s AND %s" % (has_key_sql, is_null_sql),
|
||||
'%s AND %s' % (has_key_sql, is_null_sql),
|
||||
tuple(has_key_params) + tuple(is_null_params),
|
||||
)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class KeyTransformIExact(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
|
||||
):
|
||||
class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIContains(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
|
||||
):
|
||||
class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
@@ -493,9 +466,7 @@ class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIStartsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
|
||||
):
|
||||
class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
|
||||
pass
|
||||
|
||||
|
||||
@@ -503,9 +474,7 @@ class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIEndsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
|
||||
):
|
||||
class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
|
||||
pass
|
||||
|
||||
|
||||
@@ -513,9 +482,7 @@ class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIRegex(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
|
||||
):
|
||||
class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
|
||||
pass
|
||||
|
||||
|
||||
@@ -562,6 +529,7 @@ KeyTransform.register_lookup(KeyTransformGte)
|
||||
|
||||
|
||||
class KeyTransformFactory:
|
||||
|
||||
def __init__(self, key_name):
|
||||
self.key_name = key_name
|
||||
|
||||
|
||||
@@ -29,25 +29,22 @@ class FieldCacheMixin:
|
||||
|
||||
|
||||
class CheckFieldDefaultMixin:
|
||||
_default_hint = ("<valid default>", "<invalid default>")
|
||||
_default_hint = ('<valid default>', '<invalid default>')
|
||||
|
||||
def _check_default(self):
|
||||
if (
|
||||
self.has_default()
|
||||
and self.default is not None
|
||||
and not callable(self.default)
|
||||
):
|
||||
if self.has_default() and self.default is not None and not callable(self.default):
|
||||
return [
|
||||
checks.Warning(
|
||||
"%s default should be a callable instead of an instance "
|
||||
"so that it's not shared between all field instances."
|
||||
% (self.__class__.__name__,),
|
||||
"so that it's not shared between all field instances." % (
|
||||
self.__class__.__name__,
|
||||
),
|
||||
hint=(
|
||||
"Use a callable instead, e.g., use `%s` instead of "
|
||||
"`%s`." % self._default_hint
|
||||
'Use a callable instead, e.g., use `%s` instead of '
|
||||
'`%s`.' % self._default_hint
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E010",
|
||||
id='fields.E010',
|
||||
)
|
||||
]
|
||||
else:
|
||||
|
||||
@@ -13,6 +13,6 @@ class OrderWrt(fields.IntegerField):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["name"] = "_order"
|
||||
kwargs["editable"] = False
|
||||
kwargs['name'] = '_order'
|
||||
kwargs['editable'] = False
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -74,9 +74,7 @@ from django.utils.functional import cached_property
|
||||
|
||||
class ForeignKeyDeferredAttribute(DeferredAttribute):
|
||||
def __set__(self, instance, value):
|
||||
if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(
|
||||
instance
|
||||
):
|
||||
if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(instance):
|
||||
self.field.delete_cached_value(instance)
|
||||
instance.__dict__[self.field.attname] = value
|
||||
|
||||
@@ -103,16 +101,14 @@ class ForwardManyToOneDescriptor:
|
||||
# related model might not be resolved yet; `self.field.model` might
|
||||
# still be a string model reference.
|
||||
return type(
|
||||
"RelatedObjectDoesNotExist",
|
||||
(self.field.remote_field.model.DoesNotExist, AttributeError),
|
||||
{
|
||||
"__module__": self.field.model.__module__,
|
||||
"__qualname__": "%s.%s.RelatedObjectDoesNotExist"
|
||||
% (
|
||||
'RelatedObjectDoesNotExist',
|
||||
(self.field.remote_field.model.DoesNotExist, AttributeError), {
|
||||
'__module__': self.field.model.__module__,
|
||||
'__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
|
||||
self.field.model.__qualname__,
|
||||
self.field.name,
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def is_cached(self, instance):
|
||||
@@ -139,12 +135,9 @@ class ForwardManyToOneDescriptor:
|
||||
# The check for len(...) == 1 is a special case that allows the query
|
||||
# to be join-less and smaller. Refs #21760.
|
||||
if remote_field.is_hidden() or len(self.field.foreign_related_fields) == 1:
|
||||
query = {
|
||||
"%s__in"
|
||||
% related_field.name: {instance_attr(inst)[0] for inst in instances}
|
||||
}
|
||||
query = {'%s__in' % related_field.name: {instance_attr(inst)[0] for inst in instances}}
|
||||
else:
|
||||
query = {"%s__in" % self.field.related_query_name(): instances}
|
||||
query = {'%s__in' % self.field.related_query_name(): instances}
|
||||
queryset = queryset.filter(**query)
|
||||
|
||||
# Since we're going to assign directly in the cache,
|
||||
@@ -153,14 +146,7 @@ class ForwardManyToOneDescriptor:
|
||||
for rel_obj in queryset:
|
||||
instance = instances_dict[rel_obj_attr(rel_obj)]
|
||||
remote_field.set_cached_value(rel_obj, instance)
|
||||
return (
|
||||
queryset,
|
||||
rel_obj_attr,
|
||||
instance_attr,
|
||||
True,
|
||||
self.field.get_cache_name(),
|
||||
False,
|
||||
)
|
||||
return queryset, rel_obj_attr, instance_attr, True, self.field.get_cache_name(), False
|
||||
|
||||
def get_object(self, instance):
|
||||
qs = self.get_queryset(instance=instance)
|
||||
@@ -187,11 +173,7 @@ class ForwardManyToOneDescriptor:
|
||||
rel_obj = self.field.get_cached_value(instance)
|
||||
except KeyError:
|
||||
has_value = None not in self.field.get_local_related_value(instance)
|
||||
ancestor_link = (
|
||||
instance._meta.get_ancestor_link(self.field.model)
|
||||
if has_value
|
||||
else None
|
||||
)
|
||||
ancestor_link = instance._meta.get_ancestor_link(self.field.model) if has_value else None
|
||||
if ancestor_link and ancestor_link.is_cached(instance):
|
||||
# An ancestor link will exist if this field is defined on a
|
||||
# multi-table inheritance parent of the instance's class.
|
||||
@@ -229,12 +211,9 @@ class ForwardManyToOneDescriptor:
|
||||
- ``value`` is the ``parent`` instance on the right of the equal sign
|
||||
"""
|
||||
# An object must be an instance of the related class.
|
||||
if value is not None and not isinstance(
|
||||
value, self.field.remote_field.model._meta.concrete_model
|
||||
):
|
||||
if value is not None and not isinstance(value, self.field.remote_field.model._meta.concrete_model):
|
||||
raise ValueError(
|
||||
'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
|
||||
% (
|
||||
'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
|
||||
value,
|
||||
instance._meta.object_name,
|
||||
self.field.name,
|
||||
@@ -243,18 +222,11 @@ class ForwardManyToOneDescriptor:
|
||||
)
|
||||
elif value is not None:
|
||||
if instance._state.db is None:
|
||||
instance._state.db = router.db_for_write(
|
||||
instance.__class__, instance=value
|
||||
)
|
||||
instance._state.db = router.db_for_write(instance.__class__, instance=value)
|
||||
if value._state.db is None:
|
||||
value._state.db = router.db_for_write(
|
||||
value.__class__, instance=instance
|
||||
)
|
||||
value._state.db = router.db_for_write(value.__class__, instance=instance)
|
||||
if not router.allow_relation(value, instance):
|
||||
raise ValueError(
|
||||
'Cannot assign "%r": the current database router prevents this '
|
||||
"relation." % value
|
||||
)
|
||||
raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
|
||||
|
||||
remote_field = self.field.remote_field
|
||||
# If we're setting the value of a OneToOneField to None, we need to clear
|
||||
@@ -342,15 +314,12 @@ class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
|
||||
opts = instance._meta
|
||||
# Inherited primary key fields from this object's base classes.
|
||||
inherited_pk_fields = [
|
||||
field
|
||||
for field in opts.concrete_fields
|
||||
field for field in opts.concrete_fields
|
||||
if field.primary_key and field.remote_field
|
||||
]
|
||||
for field in inherited_pk_fields:
|
||||
rel_model_pk_name = field.remote_field.model._meta.pk.attname
|
||||
raw_value = (
|
||||
getattr(value, rel_model_pk_name) if value is not None else None
|
||||
)
|
||||
raw_value = getattr(value, rel_model_pk_name) if value is not None else None
|
||||
setattr(instance, rel_model_pk_name, raw_value)
|
||||
|
||||
|
||||
@@ -377,15 +346,13 @@ class ReverseOneToOneDescriptor:
|
||||
# The exception isn't created at initialization time for the sake of
|
||||
# consistency with `ForwardManyToOneDescriptor`.
|
||||
return type(
|
||||
"RelatedObjectDoesNotExist",
|
||||
(self.related.related_model.DoesNotExist, AttributeError),
|
||||
{
|
||||
"__module__": self.related.model.__module__,
|
||||
"__qualname__": "%s.%s.RelatedObjectDoesNotExist"
|
||||
% (
|
||||
'RelatedObjectDoesNotExist',
|
||||
(self.related.related_model.DoesNotExist, AttributeError), {
|
||||
'__module__': self.related.model.__module__,
|
||||
'__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
|
||||
self.related.model.__qualname__,
|
||||
self.related.name,
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -403,7 +370,7 @@ class ReverseOneToOneDescriptor:
|
||||
rel_obj_attr = self.related.field.get_local_related_value
|
||||
instance_attr = self.related.field.get_foreign_related_value
|
||||
instances_dict = {instance_attr(inst): inst for inst in instances}
|
||||
query = {"%s__in" % self.related.field.name: instances}
|
||||
query = {'%s__in' % self.related.field.name: instances}
|
||||
queryset = queryset.filter(**query)
|
||||
|
||||
# Since we're going to assign directly in the cache,
|
||||
@@ -411,14 +378,7 @@ class ReverseOneToOneDescriptor:
|
||||
for rel_obj in queryset:
|
||||
instance = instances_dict[rel_obj_attr(rel_obj)]
|
||||
self.related.field.set_cached_value(rel_obj, instance)
|
||||
return (
|
||||
queryset,
|
||||
rel_obj_attr,
|
||||
instance_attr,
|
||||
True,
|
||||
self.related.get_cache_name(),
|
||||
False,
|
||||
)
|
||||
return queryset, rel_obj_attr, instance_attr, True, self.related.get_cache_name(), False
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
"""
|
||||
@@ -459,8 +419,10 @@ class ReverseOneToOneDescriptor:
|
||||
|
||||
if rel_obj is None:
|
||||
raise self.RelatedObjectDoesNotExist(
|
||||
"%s has no %s."
|
||||
% (instance.__class__.__name__, self.related.get_accessor_name())
|
||||
"%s has no %s." % (
|
||||
instance.__class__.__name__,
|
||||
self.related.get_accessor_name()
|
||||
)
|
||||
)
|
||||
else:
|
||||
return rel_obj
|
||||
@@ -496,8 +458,7 @@ class ReverseOneToOneDescriptor:
|
||||
elif not isinstance(value, self.related.related_model):
|
||||
# An object must be an instance of the related class.
|
||||
raise ValueError(
|
||||
'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
|
||||
% (
|
||||
'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
|
||||
value,
|
||||
instance._meta.object_name,
|
||||
self.related.get_accessor_name(),
|
||||
@@ -506,25 +467,14 @@ class ReverseOneToOneDescriptor:
|
||||
)
|
||||
else:
|
||||
if instance._state.db is None:
|
||||
instance._state.db = router.db_for_write(
|
||||
instance.__class__, instance=value
|
||||
)
|
||||
instance._state.db = router.db_for_write(instance.__class__, instance=value)
|
||||
if value._state.db is None:
|
||||
value._state.db = router.db_for_write(
|
||||
value.__class__, instance=instance
|
||||
)
|
||||
value._state.db = router.db_for_write(value.__class__, instance=instance)
|
||||
if not router.allow_relation(value, instance):
|
||||
raise ValueError(
|
||||
'Cannot assign "%r": the current database router prevents this '
|
||||
"relation." % value
|
||||
)
|
||||
raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
|
||||
|
||||
related_pk = tuple(
|
||||
getattr(instance, field.attname)
|
||||
for field in self.related.field.foreign_related_fields
|
||||
)
|
||||
# Set the value of the related field to the value of the related
|
||||
# object's related field.
|
||||
related_pk = tuple(getattr(instance, field.attname) for field in self.related.field.foreign_related_fields)
|
||||
# Set the value of the related field to the value of the related object's related field
|
||||
for index, field in enumerate(self.related.field.local_related_fields):
|
||||
setattr(value, field.attname, related_pk[index])
|
||||
|
||||
@@ -587,13 +537,13 @@ class ReverseManyToOneDescriptor:
|
||||
|
||||
def _get_set_deprecation_msg_params(self):
|
||||
return (
|
||||
"reverse side of a related set",
|
||||
'reverse side of a related set',
|
||||
self.rel.get_accessor_name(),
|
||||
)
|
||||
|
||||
def __set__(self, instance, value):
|
||||
raise TypeError(
|
||||
"Direct assignment to the %s is prohibited. Use %s.set() instead."
|
||||
'Direct assignment to the %s is prohibited. Use %s.set() instead.'
|
||||
% self._get_set_deprecation_msg_params(),
|
||||
)
|
||||
|
||||
@@ -620,7 +570,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
manager = getattr(self.model, manager)
|
||||
manager_class = create_reverse_many_to_one_manager(manager.__class__, rel)
|
||||
return manager_class(self.instance)
|
||||
|
||||
do_not_call_in_templates = True
|
||||
|
||||
def _apply_rel_filters(self, queryset):
|
||||
@@ -628,9 +577,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
Filter the queryset for the instance this manager is bound to.
|
||||
"""
|
||||
db = self._db or router.db_for_read(self.model, instance=self.instance)
|
||||
empty_strings_as_null = connections[
|
||||
db
|
||||
].features.interprets_empty_strings_as_nulls
|
||||
empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
|
||||
queryset._add_hints(instance=self.instance)
|
||||
if self._db:
|
||||
queryset = queryset.using(self._db)
|
||||
@@ -638,7 +585,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
queryset = queryset.filter(**self.core_filters)
|
||||
for field in self.field.foreign_related_fields:
|
||||
val = getattr(self.instance, field.attname)
|
||||
if val is None or (val == "" and empty_strings_as_null):
|
||||
if val is None or (val == '' and empty_strings_as_null):
|
||||
return queryset.none()
|
||||
if self.field.many_to_one:
|
||||
# Guard against field-like objects such as GenericRelation
|
||||
@@ -650,34 +597,24 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
except FieldError:
|
||||
# The relationship has multiple target fields. Use a tuple
|
||||
# for related object id.
|
||||
rel_obj_id = tuple(
|
||||
[
|
||||
getattr(self.instance, target_field.attname)
|
||||
for target_field in self.field.get_path_info()[
|
||||
-1
|
||||
].target_fields
|
||||
]
|
||||
)
|
||||
rel_obj_id = tuple([
|
||||
getattr(self.instance, target_field.attname)
|
||||
for target_field in self.field.get_path_info()[-1].target_fields
|
||||
])
|
||||
else:
|
||||
rel_obj_id = getattr(self.instance, target_field.attname)
|
||||
queryset._known_related_objects = {
|
||||
self.field: {rel_obj_id: self.instance}
|
||||
}
|
||||
queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
|
||||
return queryset
|
||||
|
||||
def _remove_prefetched_objects(self):
|
||||
try:
|
||||
self.instance._prefetched_objects_cache.pop(
|
||||
self.field.remote_field.get_cache_name()
|
||||
)
|
||||
self.instance._prefetched_objects_cache.pop(self.field.remote_field.get_cache_name())
|
||||
except (AttributeError, KeyError):
|
||||
pass # nothing to clear from cache
|
||||
|
||||
def get_queryset(self):
|
||||
try:
|
||||
return self.instance._prefetched_objects_cache[
|
||||
self.field.remote_field.get_cache_name()
|
||||
]
|
||||
return self.instance._prefetched_objects_cache[self.field.remote_field.get_cache_name()]
|
||||
except (AttributeError, KeyError):
|
||||
queryset = super().get_queryset()
|
||||
return self._apply_rel_filters(queryset)
|
||||
@@ -692,7 +629,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
rel_obj_attr = self.field.get_local_related_value
|
||||
instance_attr = self.field.get_foreign_related_value
|
||||
instances_dict = {instance_attr(inst): inst for inst in instances}
|
||||
query = {"%s__in" % self.field.name: instances}
|
||||
query = {'%s__in' % self.field.name: instances}
|
||||
queryset = queryset.filter(**query)
|
||||
|
||||
# Since we just bypassed this class' get_queryset(), we must manage
|
||||
@@ -709,13 +646,9 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
|
||||
def check_and_update_obj(obj):
|
||||
if not isinstance(obj, self.model):
|
||||
raise TypeError(
|
||||
"'%s' instance expected, got %r"
|
||||
% (
|
||||
self.model._meta.object_name,
|
||||
obj,
|
||||
)
|
||||
)
|
||||
raise TypeError("'%s' instance expected, got %r" % (
|
||||
self.model._meta.object_name, obj,
|
||||
))
|
||||
setattr(obj, self.field.name, self.instance)
|
||||
|
||||
if bulk:
|
||||
@@ -728,44 +661,36 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
"the object first." % obj
|
||||
)
|
||||
pks.append(obj.pk)
|
||||
self.model._base_manager.using(db).filter(pk__in=pks).update(
|
||||
**{
|
||||
self.field.name: self.instance,
|
||||
}
|
||||
)
|
||||
self.model._base_manager.using(db).filter(pk__in=pks).update(**{
|
||||
self.field.name: self.instance,
|
||||
})
|
||||
else:
|
||||
with transaction.atomic(using=db, savepoint=False):
|
||||
for obj in objs:
|
||||
check_and_update_obj(obj)
|
||||
obj.save()
|
||||
|
||||
add.alters_data = True
|
||||
|
||||
def create(self, **kwargs):
|
||||
kwargs[self.field.name] = self.instance
|
||||
db = router.db_for_write(self.model, instance=self.instance)
|
||||
return super(RelatedManager, self.db_manager(db)).create(**kwargs)
|
||||
|
||||
create.alters_data = True
|
||||
|
||||
def get_or_create(self, **kwargs):
|
||||
kwargs[self.field.name] = self.instance
|
||||
db = router.db_for_write(self.model, instance=self.instance)
|
||||
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
|
||||
|
||||
get_or_create.alters_data = True
|
||||
|
||||
def update_or_create(self, **kwargs):
|
||||
kwargs[self.field.name] = self.instance
|
||||
db = router.db_for_write(self.model, instance=self.instance)
|
||||
return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs)
|
||||
|
||||
update_or_create.alters_data = True
|
||||
|
||||
# remove() and clear() are only provided if the ForeignKey can have a
|
||||
# value of null.
|
||||
# remove() and clear() are only provided if the ForeignKey can have a value of null.
|
||||
if rel.field.null:
|
||||
|
||||
def remove(self, *objs, bulk=True):
|
||||
if not objs:
|
||||
return
|
||||
@@ -773,13 +698,9 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
old_ids = set()
|
||||
for obj in objs:
|
||||
if not isinstance(obj, self.model):
|
||||
raise TypeError(
|
||||
"'%s' instance expected, got %r"
|
||||
% (
|
||||
self.model._meta.object_name,
|
||||
obj,
|
||||
)
|
||||
)
|
||||
raise TypeError("'%s' instance expected, got %r" % (
|
||||
self.model._meta.object_name, obj,
|
||||
))
|
||||
# Is obj actually part of this descriptor set?
|
||||
if self.field.get_local_related_value(obj) == val:
|
||||
old_ids.add(obj.pk)
|
||||
@@ -788,12 +709,10 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
"%r is not related to %r." % (obj, self.instance)
|
||||
)
|
||||
self._clear(self.filter(pk__in=old_ids), bulk)
|
||||
|
||||
remove.alters_data = True
|
||||
|
||||
def clear(self, *, bulk=True):
|
||||
self._clear(self, bulk)
|
||||
|
||||
clear.alters_data = True
|
||||
|
||||
def _clear(self, queryset, bulk):
|
||||
@@ -808,7 +727,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
for obj in queryset:
|
||||
setattr(obj, self.field.name, None)
|
||||
obj.save(update_fields=[self.field.name])
|
||||
|
||||
_clear.alters_data = True
|
||||
|
||||
def set(self, objs, *, bulk=True, clear=False):
|
||||
@@ -835,7 +753,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
|
||||
self.add(*new_objs, bulk=bulk)
|
||||
else:
|
||||
self.add(*objs, bulk=bulk)
|
||||
|
||||
set.alters_data = True
|
||||
|
||||
return RelatedManager
|
||||
@@ -882,8 +799,7 @@ class ManyToManyDescriptor(ReverseManyToOneDescriptor):
|
||||
|
||||
def _get_set_deprecation_msg_params(self):
|
||||
return (
|
||||
"%s side of a many-to-many set"
|
||||
% ("reverse" if self.reverse else "forward"),
|
||||
'%s side of a many-to-many set' % ('reverse' if self.reverse else 'forward'),
|
||||
self.rel.get_accessor_name() if self.reverse else self.field.name,
|
||||
)
|
||||
|
||||
@@ -926,51 +842,42 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
self.core_filters = {}
|
||||
self.pk_field_names = {}
|
||||
for lh_field, rh_field in self.source_field.related_fields:
|
||||
core_filter_key = "%s__%s" % (self.query_field_name, rh_field.name)
|
||||
core_filter_key = '%s__%s' % (self.query_field_name, rh_field.name)
|
||||
self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
|
||||
self.pk_field_names[lh_field.name] = rh_field.name
|
||||
|
||||
self.related_val = self.source_field.get_foreign_related_value(instance)
|
||||
if None in self.related_val:
|
||||
raise ValueError(
|
||||
'"%r" needs to have a value for field "%s" before '
|
||||
"this many-to-many relationship can be used."
|
||||
% (instance, self.pk_field_names[self.source_field_name])
|
||||
)
|
||||
raise ValueError('"%r" needs to have a value for field "%s" before '
|
||||
'this many-to-many relationship can be used.' %
|
||||
(instance, self.pk_field_names[self.source_field_name]))
|
||||
# Even if this relation is not to pk, we require still pk value.
|
||||
# The wish is that the instance has been already saved to DB,
|
||||
# although having a pk value isn't a guarantee of that.
|
||||
if instance.pk is None:
|
||||
raise ValueError(
|
||||
"%r instance needs to have a primary key value before "
|
||||
"a many-to-many relationship can be used."
|
||||
% instance.__class__.__name__
|
||||
)
|
||||
raise ValueError("%r instance needs to have a primary key value before "
|
||||
"a many-to-many relationship can be used." %
|
||||
instance.__class__.__name__)
|
||||
|
||||
def __call__(self, *, manager):
|
||||
manager = getattr(self.model, manager)
|
||||
manager_class = create_forward_many_to_many_manager(
|
||||
manager.__class__, rel, reverse
|
||||
)
|
||||
manager_class = create_forward_many_to_many_manager(manager.__class__, rel, reverse)
|
||||
return manager_class(instance=self.instance)
|
||||
|
||||
do_not_call_in_templates = True
|
||||
|
||||
def _build_remove_filters(self, removed_vals):
|
||||
filters = Q((self.source_field_name, self.related_val))
|
||||
filters = Q(**{self.source_field_name: self.related_val})
|
||||
# No need to add a subquery condition if removed_vals is a QuerySet without
|
||||
# filters.
|
||||
removed_vals_filters = (
|
||||
not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()
|
||||
)
|
||||
removed_vals_filters = (not isinstance(removed_vals, QuerySet) or
|
||||
removed_vals._has_filters())
|
||||
if removed_vals_filters:
|
||||
filters &= Q((f"{self.target_field_name}__in", removed_vals))
|
||||
filters &= Q(**{'%s__in' % self.target_field_name: removed_vals})
|
||||
if self.symmetrical:
|
||||
symmetrical_filters = Q((self.target_field_name, self.related_val))
|
||||
symmetrical_filters = Q(**{self.target_field_name: self.related_val})
|
||||
if removed_vals_filters:
|
||||
symmetrical_filters &= Q(
|
||||
(f"{self.source_field_name}__in", removed_vals)
|
||||
)
|
||||
**{'%s__in' % self.source_field_name: removed_vals})
|
||||
filters |= symmetrical_filters
|
||||
return filters
|
||||
|
||||
@@ -1004,7 +911,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
queryset._add_hints(instance=instances[0])
|
||||
queryset = queryset.using(queryset._db or self._db)
|
||||
|
||||
query = {"%s__in" % self.query_field_name: instances}
|
||||
query = {'%s__in' % self.query_field_name: instances}
|
||||
queryset = queryset._next_is_sticky().filter(**query)
|
||||
|
||||
# M2M: need to annotate the query in order to get the primary model
|
||||
@@ -1018,18 +925,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
join_table = fk.model._meta.db_table
|
||||
connection = connections[queryset.db]
|
||||
qn = connection.ops.quote_name
|
||||
queryset = queryset.extra(
|
||||
select={
|
||||
"_prefetch_related_val_%s"
|
||||
% f.attname: "%s.%s"
|
||||
% (qn(join_table), qn(f.column))
|
||||
for f in fk.local_related_fields
|
||||
}
|
||||
)
|
||||
queryset = queryset.extra(select={
|
||||
'_prefetch_related_val_%s' % f.attname:
|
||||
'%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
|
||||
return (
|
||||
queryset,
|
||||
lambda result: tuple(
|
||||
getattr(result, "_prefetch_related_val_%s" % f.attname)
|
||||
getattr(result, '_prefetch_related_val_%s' % f.attname)
|
||||
for f in fk.local_related_fields
|
||||
),
|
||||
lambda inst: tuple(
|
||||
@@ -1046,9 +948,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
db = router.db_for_write(self.through, instance=self.instance)
|
||||
with transaction.atomic(using=db, savepoint=False):
|
||||
self._add_items(
|
||||
self.source_field_name,
|
||||
self.target_field_name,
|
||||
*objs,
|
||||
self.source_field_name, self.target_field_name, *objs,
|
||||
through_defaults=through_defaults,
|
||||
)
|
||||
# If this is a symmetrical m2m relation to self, add the mirror
|
||||
@@ -1060,41 +960,30 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
*objs,
|
||||
through_defaults=through_defaults,
|
||||
)
|
||||
|
||||
add.alters_data = True
|
||||
|
||||
def remove(self, *objs):
|
||||
self._remove_prefetched_objects()
|
||||
self._remove_items(self.source_field_name, self.target_field_name, *objs)
|
||||
|
||||
remove.alters_data = True
|
||||
|
||||
def clear(self):
|
||||
db = router.db_for_write(self.through, instance=self.instance)
|
||||
with transaction.atomic(using=db, savepoint=False):
|
||||
signals.m2m_changed.send(
|
||||
sender=self.through,
|
||||
action="pre_clear",
|
||||
instance=self.instance,
|
||||
reverse=self.reverse,
|
||||
model=self.model,
|
||||
pk_set=None,
|
||||
using=db,
|
||||
sender=self.through, action="pre_clear",
|
||||
instance=self.instance, reverse=self.reverse,
|
||||
model=self.model, pk_set=None, using=db,
|
||||
)
|
||||
self._remove_prefetched_objects()
|
||||
filters = self._build_remove_filters(super().get_queryset().using(db))
|
||||
self.through._default_manager.using(db).filter(filters).delete()
|
||||
|
||||
signals.m2m_changed.send(
|
||||
sender=self.through,
|
||||
action="post_clear",
|
||||
instance=self.instance,
|
||||
reverse=self.reverse,
|
||||
model=self.model,
|
||||
pk_set=None,
|
||||
using=db,
|
||||
sender=self.through, action="post_clear",
|
||||
instance=self.instance, reverse=self.reverse,
|
||||
model=self.model, pk_set=None, using=db,
|
||||
)
|
||||
|
||||
clear.alters_data = True
|
||||
|
||||
def set(self, objs, *, clear=False, through_defaults=None):
|
||||
@@ -1108,11 +997,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
self.clear()
|
||||
self.add(*objs, through_defaults=through_defaults)
|
||||
else:
|
||||
old_ids = set(
|
||||
self.using(db).values_list(
|
||||
self.target_field.target_field.attname, flat=True
|
||||
)
|
||||
)
|
||||
old_ids = set(self.using(db).values_list(self.target_field.target_field.attname, flat=True))
|
||||
|
||||
new_objs = []
|
||||
for obj in objs:
|
||||
@@ -1128,7 +1013,6 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
|
||||
self.remove(*old_ids)
|
||||
self.add(*new_objs, through_defaults=through_defaults)
|
||||
|
||||
set.alters_data = True
|
||||
|
||||
def create(self, *, through_defaults=None, **kwargs):
|
||||
@@ -1136,33 +1020,26 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
|
||||
self.add(new_obj, through_defaults=through_defaults)
|
||||
return new_obj
|
||||
|
||||
create.alters_data = True
|
||||
|
||||
def get_or_create(self, *, through_defaults=None, **kwargs):
|
||||
db = router.db_for_write(self.instance.__class__, instance=self.instance)
|
||||
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
|
||||
**kwargs
|
||||
)
|
||||
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs)
|
||||
# We only need to add() if created because if we got an object back
|
||||
# from get() then the relationship already exists.
|
||||
if created:
|
||||
self.add(obj, through_defaults=through_defaults)
|
||||
return obj, created
|
||||
|
||||
get_or_create.alters_data = True
|
||||
|
||||
def update_or_create(self, *, through_defaults=None, **kwargs):
|
||||
db = router.db_for_write(self.instance.__class__, instance=self.instance)
|
||||
obj, created = super(
|
||||
ManyRelatedManager, self.db_manager(db)
|
||||
).update_or_create(**kwargs)
|
||||
obj, created = super(ManyRelatedManager, self.db_manager(db)).update_or_create(**kwargs)
|
||||
# We only need to add() if created because if we got an object back
|
||||
# from get() then the relationship already exists.
|
||||
if created:
|
||||
self.add(obj, through_defaults=through_defaults)
|
||||
return obj, created
|
||||
|
||||
update_or_create.alters_data = True
|
||||
|
||||
def _get_target_ids(self, target_field_name, objs):
|
||||
@@ -1170,7 +1047,6 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
Return the set of ids of `objs` that the target field references.
|
||||
"""
|
||||
from django.db.models import Model
|
||||
|
||||
target_ids = set()
|
||||
target_field = self.through._meta.get_field(target_field_name)
|
||||
for obj in objs:
|
||||
@@ -1178,42 +1054,36 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
if not router.allow_relation(obj, self.instance):
|
||||
raise ValueError(
|
||||
'Cannot add "%r": instance is on database "%s", '
|
||||
'value is on database "%s"'
|
||||
% (obj, self.instance._state.db, obj._state.db)
|
||||
'value is on database "%s"' %
|
||||
(obj, self.instance._state.db, obj._state.db)
|
||||
)
|
||||
target_id = target_field.get_foreign_related_value(obj)[0]
|
||||
if target_id is None:
|
||||
raise ValueError(
|
||||
'Cannot add "%r": the value for field "%s" is None'
|
||||
% (obj, target_field_name)
|
||||
'Cannot add "%r": the value for field "%s" is None' %
|
||||
(obj, target_field_name)
|
||||
)
|
||||
target_ids.add(target_id)
|
||||
elif isinstance(obj, Model):
|
||||
raise TypeError(
|
||||
"'%s' instance expected, got %r"
|
||||
% (self.model._meta.object_name, obj)
|
||||
"'%s' instance expected, got %r" %
|
||||
(self.model._meta.object_name, obj)
|
||||
)
|
||||
else:
|
||||
target_ids.add(target_field.get_prep_value(obj))
|
||||
return target_ids
|
||||
|
||||
def _get_missing_target_ids(
|
||||
self, source_field_name, target_field_name, db, target_ids
|
||||
):
|
||||
def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids):
|
||||
"""
|
||||
Return the subset of ids of `objs` that aren't already assigned to
|
||||
this relationship.
|
||||
"""
|
||||
vals = (
|
||||
self.through._default_manager.using(db)
|
||||
.values_list(target_field_name, flat=True)
|
||||
.filter(
|
||||
**{
|
||||
source_field_name: self.related_val[0],
|
||||
"%s__in" % target_field_name: target_ids,
|
||||
}
|
||||
)
|
||||
)
|
||||
vals = self.through._default_manager.using(db).values_list(
|
||||
target_field_name, flat=True
|
||||
).filter(**{
|
||||
source_field_name: self.related_val[0],
|
||||
'%s__in' % target_field_name: target_ids,
|
||||
})
|
||||
return target_ids.difference(vals)
|
||||
|
||||
def _get_add_plan(self, db, source_field_name):
|
||||
@@ -1231,53 +1101,39 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
# user-defined intermediary models as they could have other fields
|
||||
# causing conflicts which must be surfaced.
|
||||
can_ignore_conflicts = (
|
||||
connections[db].features.supports_ignore_conflicts
|
||||
and self.through._meta.auto_created is not False
|
||||
connections[db].features.supports_ignore_conflicts and
|
||||
self.through._meta.auto_created is not False
|
||||
)
|
||||
# Don't send the signal when inserting duplicate data row
|
||||
# for symmetrical reverse entries.
|
||||
must_send_signals = (
|
||||
self.reverse or source_field_name == self.source_field_name
|
||||
) and (signals.m2m_changed.has_listeners(self.through))
|
||||
must_send_signals = (self.reverse or source_field_name == self.source_field_name) and (
|
||||
signals.m2m_changed.has_listeners(self.through)
|
||||
)
|
||||
# Fast addition through bulk insertion can only be performed
|
||||
# if no m2m_changed listeners are connected for self.through
|
||||
# as they require the added set of ids to be provided via
|
||||
# pk_set.
|
||||
return (
|
||||
can_ignore_conflicts,
|
||||
must_send_signals,
|
||||
(can_ignore_conflicts and not must_send_signals),
|
||||
)
|
||||
return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals)
|
||||
|
||||
def _add_items(
|
||||
self, source_field_name, target_field_name, *objs, through_defaults=None
|
||||
):
|
||||
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
|
||||
# source_field_name: the PK fieldname in join table for the source object
|
||||
# target_field_name: the PK fieldname in join table for the target object
|
||||
# *objs - objects to add. Either object instances, or primary keys
|
||||
# of object instances.
|
||||
# *objs - objects to add. Either object instances, or primary keys of object instances.
|
||||
if not objs:
|
||||
return
|
||||
|
||||
through_defaults = dict(resolve_callables(through_defaults or {}))
|
||||
target_ids = self._get_target_ids(target_field_name, objs)
|
||||
db = router.db_for_write(self.through, instance=self.instance)
|
||||
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(
|
||||
db, source_field_name
|
||||
)
|
||||
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name)
|
||||
if can_fast_add:
|
||||
self.through._default_manager.using(db).bulk_create(
|
||||
[
|
||||
self.through(
|
||||
**{
|
||||
"%s_id" % source_field_name: self.related_val[0],
|
||||
"%s_id" % target_field_name: target_id,
|
||||
}
|
||||
)
|
||||
for target_id in target_ids
|
||||
],
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
self.through._default_manager.using(db).bulk_create([
|
||||
self.through(**{
|
||||
'%s_id' % source_field_name: self.related_val[0],
|
||||
'%s_id' % target_field_name: target_id,
|
||||
})
|
||||
for target_id in target_ids
|
||||
], ignore_conflicts=True)
|
||||
return
|
||||
|
||||
missing_target_ids = self._get_missing_target_ids(
|
||||
@@ -1286,38 +1142,24 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
with transaction.atomic(using=db, savepoint=False):
|
||||
if must_send_signals:
|
||||
signals.m2m_changed.send(
|
||||
sender=self.through,
|
||||
action="pre_add",
|
||||
instance=self.instance,
|
||||
reverse=self.reverse,
|
||||
model=self.model,
|
||||
pk_set=missing_target_ids,
|
||||
using=db,
|
||||
sender=self.through, action='pre_add',
|
||||
instance=self.instance, reverse=self.reverse,
|
||||
model=self.model, pk_set=missing_target_ids, using=db,
|
||||
)
|
||||
# Add the ones that aren't there already.
|
||||
self.through._default_manager.using(db).bulk_create(
|
||||
[
|
||||
self.through(
|
||||
**through_defaults,
|
||||
**{
|
||||
"%s_id" % source_field_name: self.related_val[0],
|
||||
"%s_id" % target_field_name: target_id,
|
||||
},
|
||||
)
|
||||
for target_id in missing_target_ids
|
||||
],
|
||||
ignore_conflicts=can_ignore_conflicts,
|
||||
)
|
||||
self.through._default_manager.using(db).bulk_create([
|
||||
self.through(**through_defaults, **{
|
||||
'%s_id' % source_field_name: self.related_val[0],
|
||||
'%s_id' % target_field_name: target_id,
|
||||
})
|
||||
for target_id in missing_target_ids
|
||||
], ignore_conflicts=can_ignore_conflicts)
|
||||
|
||||
if must_send_signals:
|
||||
signals.m2m_changed.send(
|
||||
sender=self.through,
|
||||
action="post_add",
|
||||
instance=self.instance,
|
||||
reverse=self.reverse,
|
||||
model=self.model,
|
||||
pk_set=missing_target_ids,
|
||||
using=db,
|
||||
sender=self.through, action='post_add',
|
||||
instance=self.instance, reverse=self.reverse,
|
||||
model=self.model, pk_set=missing_target_ids, using=db,
|
||||
)
|
||||
|
||||
def _remove_items(self, source_field_name, target_field_name, *objs):
|
||||
@@ -1341,32 +1183,23 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
|
||||
with transaction.atomic(using=db, savepoint=False):
|
||||
# Send a signal to the other end if need be.
|
||||
signals.m2m_changed.send(
|
||||
sender=self.through,
|
||||
action="pre_remove",
|
||||
instance=self.instance,
|
||||
reverse=self.reverse,
|
||||
model=self.model,
|
||||
pk_set=old_ids,
|
||||
using=db,
|
||||
sender=self.through, action="pre_remove",
|
||||
instance=self.instance, reverse=self.reverse,
|
||||
model=self.model, pk_set=old_ids, using=db,
|
||||
)
|
||||
target_model_qs = super().get_queryset()
|
||||
if target_model_qs._has_filters():
|
||||
old_vals = target_model_qs.using(db).filter(
|
||||
**{"%s__in" % self.target_field.target_field.attname: old_ids}
|
||||
)
|
||||
old_vals = target_model_qs.using(db).filter(**{
|
||||
'%s__in' % self.target_field.target_field.attname: old_ids})
|
||||
else:
|
||||
old_vals = old_ids
|
||||
filters = self._build_remove_filters(old_vals)
|
||||
self.through._default_manager.using(db).filter(filters).delete()
|
||||
|
||||
signals.m2m_changed.send(
|
||||
sender=self.through,
|
||||
action="post_remove",
|
||||
instance=self.instance,
|
||||
reverse=self.reverse,
|
||||
model=self.model,
|
||||
pk_set=old_ids,
|
||||
using=db,
|
||||
sender=self.through, action="post_remove",
|
||||
instance=self.instance, reverse=self.reverse,
|
||||
model=self.model, pk_set=old_ids, using=db,
|
||||
)
|
||||
|
||||
return ManyRelatedManager
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
|
||||
@@ -13,40 +8,29 @@ class MultiColSource:
|
||||
contains_aggregate = False
|
||||
|
||||
def __init__(self, alias, targets, sources, field):
|
||||
self.targets, self.sources, self.field, self.alias = (
|
||||
targets,
|
||||
sources,
|
||||
field,
|
||||
alias,
|
||||
)
|
||||
self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
|
||||
self.output_field = self.field
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
|
||||
return "{}({}, {})".format(
|
||||
self.__class__.__name__, self.alias, self.field)
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
return self.__class__(
|
||||
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
|
||||
)
|
||||
return self.__class__(relabels.get(self.alias, self.alias),
|
||||
self.targets, self.sources, self.field)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
|
||||
def get_normalized_value(value, lhs):
|
||||
from django.db.models import Model
|
||||
|
||||
if isinstance(value, Model):
|
||||
value_list = []
|
||||
sources = lhs.output_field.get_path_info()[-1].target_fields
|
||||
for source in sources:
|
||||
while not isinstance(value, source.model) and source.remote_field:
|
||||
source = source.remote_field.model._meta.get_field(
|
||||
source.remote_field.field_name
|
||||
)
|
||||
source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
|
||||
try:
|
||||
value_list.append(getattr(value, source.attname))
|
||||
except AttributeError:
|
||||
@@ -68,26 +52,20 @@ class RelatedIn(In):
|
||||
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
|
||||
# doesn't have validation for non-integers, so we must run validation
|
||||
# using the target field.
|
||||
if hasattr(self.lhs.output_field, "get_path_info"):
|
||||
if hasattr(self.lhs.output_field, 'get_path_info'):
|
||||
# Run the target field's get_prep_value. We can safely assume there is
|
||||
# only one as we don't get to the direct value branch otherwise.
|
||||
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[
|
||||
-1
|
||||
]
|
||||
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
|
||||
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, MultiColSource):
|
||||
# For multicolumn lookups we need to build a multicolumn where clause.
|
||||
# This clause is either a SubqueryConstraint (for values that need
|
||||
# to be compiled to SQL) or an OR-combined list of
|
||||
# (col1 = val1 AND col2 = val2 AND ...) clauses.
|
||||
# This clause is either a SubqueryConstraint (for values that need to be compiled to
|
||||
# SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
|
||||
from django.db.models.sql.where import (
|
||||
AND,
|
||||
OR,
|
||||
SubqueryConstraint,
|
||||
WhereNode,
|
||||
AND, OR, SubqueryConstraint, WhereNode,
|
||||
)
|
||||
|
||||
root_constraint = WhereNode(connector=OR)
|
||||
@@ -95,35 +73,24 @@ class RelatedIn(In):
|
||||
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
|
||||
for value in values:
|
||||
value_constraint = WhereNode()
|
||||
for source, target, val in zip(
|
||||
self.lhs.sources, self.lhs.targets, value
|
||||
):
|
||||
lookup_class = target.get_lookup("exact")
|
||||
lookup = lookup_class(
|
||||
target.get_col(self.lhs.alias, source), val
|
||||
)
|
||||
for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
|
||||
lookup_class = target.get_lookup('exact')
|
||||
lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
|
||||
value_constraint.add(lookup, AND)
|
||||
root_constraint.add(value_constraint, OR)
|
||||
else:
|
||||
root_constraint.add(
|
||||
SubqueryConstraint(
|
||||
self.lhs.alias,
|
||||
[target.column for target in self.lhs.targets],
|
||||
[source.name for source in self.lhs.sources],
|
||||
self.rhs,
|
||||
),
|
||||
AND,
|
||||
)
|
||||
self.lhs.alias, [target.column for target in self.lhs.targets],
|
||||
[source.name for source in self.lhs.sources], self.rhs),
|
||||
AND)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
else:
|
||||
if not getattr(self.rhs, "has_select_fields", True) and not getattr(
|
||||
self.lhs.field.target_field, "primary_key", False
|
||||
):
|
||||
if (not getattr(self.rhs, 'has_select_fields', True) and
|
||||
not getattr(self.lhs.field.target_field, 'primary_key', False)):
|
||||
self.rhs.clear_select_clause()
|
||||
if (
|
||||
getattr(self.lhs.output_field, "primary_key", False)
|
||||
and self.lhs.output_field.model == self.rhs.model
|
||||
):
|
||||
if (getattr(self.lhs.output_field, 'primary_key', False) and
|
||||
self.lhs.output_field.model == self.rhs.model):
|
||||
# A case like Restaurant.objects.filter(place__in=restaurant_qs),
|
||||
# where place is a OneToOneField and the primary key of
|
||||
# Restaurant.
|
||||
@@ -136,21 +103,17 @@ class RelatedIn(In):
|
||||
|
||||
class RelatedLookupMixin:
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, MultiColSource) and not hasattr(
|
||||
self.rhs, "resolve_expression"
|
||||
):
|
||||
if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'):
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
|
||||
# We need to run the related field's get_prep_value(). Consider case
|
||||
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
|
||||
# doesn't have validation for non-integers, so we must run validation
|
||||
# using the target field.
|
||||
if self.prepare_rhs and hasattr(self.lhs.output_field, "get_path_info"):
|
||||
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
|
||||
# Get the target field. We can safely assume there is only one
|
||||
# as we don't get to the direct value branch otherwise.
|
||||
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[
|
||||
-1
|
||||
]
|
||||
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
|
||||
self.rhs = target_field.get_prep_value(self.rhs)
|
||||
|
||||
return super().get_prep_lookup()
|
||||
@@ -160,15 +123,11 @@ class RelatedLookupMixin:
|
||||
assert self.rhs_is_direct_value()
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)
|
||||
from django.db.models.sql.where import AND, WhereNode
|
||||
|
||||
root_constraint = WhereNode()
|
||||
for target, source, val in zip(
|
||||
self.lhs.targets, self.lhs.sources, self.rhs
|
||||
):
|
||||
for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
|
||||
lookup_class = target.get_lookup(self.lookup_name)
|
||||
root_constraint.add(
|
||||
lookup_class(target.get_col(self.lhs.alias, source), val), AND
|
||||
)
|
||||
lookup_class(target.get_col(self.lhs.alias, source), val), AND)
|
||||
return root_constraint.as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
@@ -36,16 +36,8 @@ class ForeignObjectRel(FieldCacheMixin):
|
||||
null = True
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
def __init__(self, field, to, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, parent_link=False, on_delete=None):
|
||||
self.field = field
|
||||
self.model = to
|
||||
self.related_name = related_name
|
||||
@@ -81,18 +73,14 @@ class ForeignObjectRel(FieldCacheMixin):
|
||||
"""
|
||||
target_fields = self.get_path_info()[-1].target_fields
|
||||
if len(target_fields) > 1:
|
||||
raise exceptions.FieldError(
|
||||
"Can't use target_field for multicolumn relations."
|
||||
)
|
||||
raise exceptions.FieldError("Can't use target_field for multicolumn relations.")
|
||||
return target_fields[0]
|
||||
|
||||
@cached_property
|
||||
def related_model(self):
|
||||
if not self.field.model:
|
||||
raise AttributeError(
|
||||
"This property can't be accessed before self.field.contribute_to_class "
|
||||
"has been called."
|
||||
)
|
||||
"This property can't be accessed before self.field.contribute_to_class has been called.")
|
||||
return self.field.model
|
||||
|
||||
@cached_property
|
||||
@@ -122,7 +110,7 @@ class ForeignObjectRel(FieldCacheMixin):
|
||||
return self.field.db_type
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s.%s>" % (
|
||||
return '<%s: %s.%s>' % (
|
||||
type(self).__name__,
|
||||
self.related_model._meta.app_label,
|
||||
self.related_model._meta.model_name,
|
||||
@@ -151,11 +139,8 @@ class ForeignObjectRel(FieldCacheMixin):
|
||||
return hash(self.identity)
|
||||
|
||||
def get_choices(
|
||||
self,
|
||||
include_blank=True,
|
||||
blank_choice=BLANK_CHOICE_DASH,
|
||||
limit_choices_to=None,
|
||||
ordering=(),
|
||||
self, include_blank=True, blank_choice=BLANK_CHOICE_DASH,
|
||||
limit_choices_to=None, ordering=(),
|
||||
):
|
||||
"""
|
||||
Return choices with a default blank choices included, for use
|
||||
@@ -168,17 +153,19 @@ class ForeignObjectRel(FieldCacheMixin):
|
||||
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
|
||||
if ordering:
|
||||
qs = qs.order_by(*ordering)
|
||||
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
|
||||
return (blank_choice if include_blank else []) + [
|
||||
(x.pk, str(x)) for x in qs
|
||||
]
|
||||
|
||||
def is_hidden(self):
|
||||
"""Should the related object be hidden?"""
|
||||
return bool(self.related_name) and self.related_name[-1] == "+"
|
||||
return bool(self.related_name) and self.related_name[-1] == '+'
|
||||
|
||||
def get_joining_columns(self):
|
||||
return self.field.get_reverse_joining_columns()
|
||||
|
||||
def get_extra_restriction(self, alias, related_alias):
|
||||
return self.field.get_extra_restriction(related_alias, alias)
|
||||
def get_extra_restriction(self, where_class, alias, related_alias):
|
||||
return self.field.get_extra_restriction(where_class, related_alias, alias)
|
||||
|
||||
def set_field_name(self):
|
||||
"""
|
||||
@@ -200,13 +187,12 @@ class ForeignObjectRel(FieldCacheMixin):
|
||||
opts = model._meta if model else self.related_model._meta
|
||||
model = model or self.related_model
|
||||
if self.multiple:
|
||||
# If this is a symmetrical m2m relation on self, there is no
|
||||
# reverse accessor.
|
||||
# If this is a symmetrical m2m relation on self, there is no reverse accessor.
|
||||
if self.symmetrical and model == self.model:
|
||||
return None
|
||||
if self.related_name:
|
||||
return self.related_name
|
||||
return opts.model_name + ("_set" if self.multiple else "")
|
||||
return opts.model_name + ('_set' if self.multiple else '')
|
||||
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
return self.field.get_reverse_path_info(filtered_relation)
|
||||
@@ -234,20 +220,10 @@ class ManyToOneRel(ForeignObjectRel):
|
||||
reverse relations into actual fields.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, parent_link=False, on_delete=None):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
field, to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
@@ -259,7 +235,7 @@ class ManyToOneRel(ForeignObjectRel):
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state.pop("related_model", None)
|
||||
state.pop('related_model', None)
|
||||
return state
|
||||
|
||||
@property
|
||||
@@ -272,9 +248,7 @@ class ManyToOneRel(ForeignObjectRel):
|
||||
"""
|
||||
field = self.model._meta.get_field(self.field_name)
|
||||
if not field.concrete:
|
||||
raise exceptions.FieldDoesNotExist(
|
||||
"No related field named '%s'" % self.field_name
|
||||
)
|
||||
raise exceptions.FieldDoesNotExist("No related field named '%s'" % self.field_name)
|
||||
return field
|
||||
|
||||
def set_field_name(self):
|
||||
@@ -289,21 +263,10 @@ class OneToOneRel(ManyToOneRel):
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, parent_link=False, on_delete=None):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
field, to, field_name,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
@@ -322,21 +285,11 @@ class ManyToManyRel(ForeignObjectRel):
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
symmetrical=True,
|
||||
through=None,
|
||||
through_fields=None,
|
||||
db_constraint=True,
|
||||
):
|
||||
def __init__(self, field, to, related_name=None, related_query_name=None,
|
||||
limit_choices_to=None, symmetrical=True, through=None,
|
||||
through_fields=None, db_constraint=True):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
field, to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
@@ -357,7 +310,7 @@ class ManyToManyRel(ForeignObjectRel):
|
||||
def identity(self):
|
||||
return super().identity + (
|
||||
self.through,
|
||||
make_hashable(self.through_fields),
|
||||
self.through_fields,
|
||||
self.db_constraint,
|
||||
)
|
||||
|
||||
@@ -371,7 +324,7 @@ class ManyToManyRel(ForeignObjectRel):
|
||||
field = opts.get_field(self.through_fields[0])
|
||||
else:
|
||||
for field in opts.fields:
|
||||
rel = getattr(field, "remote_field", None)
|
||||
rel = getattr(field, 'remote_field', None)
|
||||
if rel and rel.model == self.model:
|
||||
break
|
||||
return field.foreign_related_fields[0]
|
||||
|
||||
@@ -1,190 +1,46 @@
|
||||
from .comparison import Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf
|
||||
from .comparison import (
|
||||
Cast, Coalesce, Collate, Greatest, JSONObject, Least, NullIf,
|
||||
)
|
||||
from .datetime import (
|
||||
Extract,
|
||||
ExtractDay,
|
||||
ExtractHour,
|
||||
ExtractIsoWeekDay,
|
||||
ExtractIsoYear,
|
||||
ExtractMinute,
|
||||
ExtractMonth,
|
||||
ExtractQuarter,
|
||||
ExtractSecond,
|
||||
ExtractWeek,
|
||||
ExtractWeekDay,
|
||||
ExtractYear,
|
||||
Now,
|
||||
Trunc,
|
||||
TruncDate,
|
||||
TruncDay,
|
||||
TruncHour,
|
||||
TruncMinute,
|
||||
TruncMonth,
|
||||
TruncQuarter,
|
||||
TruncSecond,
|
||||
TruncTime,
|
||||
TruncWeek,
|
||||
Extract, ExtractDay, ExtractHour, ExtractIsoWeekDay, ExtractIsoYear,
|
||||
ExtractMinute, ExtractMonth, ExtractQuarter, ExtractSecond, ExtractWeek,
|
||||
ExtractWeekDay, ExtractYear, Now, Trunc, TruncDate, TruncDay, TruncHour,
|
||||
TruncMinute, TruncMonth, TruncQuarter, TruncSecond, TruncTime, TruncWeek,
|
||||
TruncYear,
|
||||
)
|
||||
from .math import (
|
||||
Abs,
|
||||
ACos,
|
||||
ASin,
|
||||
ATan,
|
||||
ATan2,
|
||||
Ceil,
|
||||
Cos,
|
||||
Cot,
|
||||
Degrees,
|
||||
Exp,
|
||||
Floor,
|
||||
Ln,
|
||||
Log,
|
||||
Mod,
|
||||
Pi,
|
||||
Power,
|
||||
Radians,
|
||||
Random,
|
||||
Round,
|
||||
Sign,
|
||||
Sin,
|
||||
Sqrt,
|
||||
Tan,
|
||||
Abs, ACos, ASin, ATan, ATan2, Ceil, Cos, Cot, Degrees, Exp, Floor, Ln, Log,
|
||||
Mod, Pi, Power, Radians, Random, Round, Sign, Sin, Sqrt, Tan,
|
||||
)
|
||||
from .text import (
|
||||
MD5,
|
||||
SHA1,
|
||||
SHA224,
|
||||
SHA256,
|
||||
SHA384,
|
||||
SHA512,
|
||||
Chr,
|
||||
Concat,
|
||||
ConcatPair,
|
||||
Left,
|
||||
Length,
|
||||
Lower,
|
||||
LPad,
|
||||
LTrim,
|
||||
Ord,
|
||||
Repeat,
|
||||
Replace,
|
||||
Reverse,
|
||||
Right,
|
||||
RPad,
|
||||
RTrim,
|
||||
StrIndex,
|
||||
Substr,
|
||||
Trim,
|
||||
Upper,
|
||||
MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left,
|
||||
Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad,
|
||||
RTrim, StrIndex, Substr, Trim, Upper,
|
||||
)
|
||||
from .window import (
|
||||
CumeDist,
|
||||
DenseRank,
|
||||
FirstValue,
|
||||
Lag,
|
||||
LastValue,
|
||||
Lead,
|
||||
NthValue,
|
||||
Ntile,
|
||||
PercentRank,
|
||||
Rank,
|
||||
RowNumber,
|
||||
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
|
||||
PercentRank, Rank, RowNumber,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# comparison and conversion
|
||||
"Cast",
|
||||
"Coalesce",
|
||||
"Collate",
|
||||
"Greatest",
|
||||
"JSONObject",
|
||||
"Least",
|
||||
"NullIf",
|
||||
'Cast', 'Coalesce', 'Collate', 'Greatest', 'JSONObject', 'Least', 'NullIf',
|
||||
# datetime
|
||||
"Extract",
|
||||
"ExtractDay",
|
||||
"ExtractHour",
|
||||
"ExtractMinute",
|
||||
"ExtractMonth",
|
||||
"ExtractQuarter",
|
||||
"ExtractSecond",
|
||||
"ExtractWeek",
|
||||
"ExtractIsoWeekDay",
|
||||
"ExtractWeekDay",
|
||||
"ExtractIsoYear",
|
||||
"ExtractYear",
|
||||
"Now",
|
||||
"Trunc",
|
||||
"TruncDate",
|
||||
"TruncDay",
|
||||
"TruncHour",
|
||||
"TruncMinute",
|
||||
"TruncMonth",
|
||||
"TruncQuarter",
|
||||
"TruncSecond",
|
||||
"TruncTime",
|
||||
"TruncWeek",
|
||||
"TruncYear",
|
||||
'Extract', 'ExtractDay', 'ExtractHour', 'ExtractMinute', 'ExtractMonth',
|
||||
'ExtractQuarter', 'ExtractSecond', 'ExtractWeek', 'ExtractIsoWeekDay',
|
||||
'ExtractWeekDay', 'ExtractIsoYear', 'ExtractYear', 'Now', 'Trunc',
|
||||
'TruncDate', 'TruncDay', 'TruncHour', 'TruncMinute', 'TruncMonth',
|
||||
'TruncQuarter', 'TruncSecond', 'TruncTime', 'TruncWeek', 'TruncYear',
|
||||
# math
|
||||
"Abs",
|
||||
"ACos",
|
||||
"ASin",
|
||||
"ATan",
|
||||
"ATan2",
|
||||
"Ceil",
|
||||
"Cos",
|
||||
"Cot",
|
||||
"Degrees",
|
||||
"Exp",
|
||||
"Floor",
|
||||
"Ln",
|
||||
"Log",
|
||||
"Mod",
|
||||
"Pi",
|
||||
"Power",
|
||||
"Radians",
|
||||
"Random",
|
||||
"Round",
|
||||
"Sign",
|
||||
"Sin",
|
||||
"Sqrt",
|
||||
"Tan",
|
||||
'Abs', 'ACos', 'ASin', 'ATan', 'ATan2', 'Ceil', 'Cos', 'Cot', 'Degrees',
|
||||
'Exp', 'Floor', 'Ln', 'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random',
|
||||
'Round', 'Sign', 'Sin', 'Sqrt', 'Tan',
|
||||
# text
|
||||
"MD5",
|
||||
"SHA1",
|
||||
"SHA224",
|
||||
"SHA256",
|
||||
"SHA384",
|
||||
"SHA512",
|
||||
"Chr",
|
||||
"Concat",
|
||||
"ConcatPair",
|
||||
"Left",
|
||||
"Length",
|
||||
"Lower",
|
||||
"LPad",
|
||||
"LTrim",
|
||||
"Ord",
|
||||
"Repeat",
|
||||
"Replace",
|
||||
"Reverse",
|
||||
"Right",
|
||||
"RPad",
|
||||
"RTrim",
|
||||
"StrIndex",
|
||||
"Substr",
|
||||
"Trim",
|
||||
"Upper",
|
||||
'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat',
|
||||
'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat',
|
||||
'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
|
||||
'Trim', 'Upper',
|
||||
# window
|
||||
"CumeDist",
|
||||
"DenseRank",
|
||||
"FirstValue",
|
||||
"Lag",
|
||||
"LastValue",
|
||||
"Lead",
|
||||
"NthValue",
|
||||
"Ntile",
|
||||
"PercentRank",
|
||||
"Rank",
|
||||
"RowNumber",
|
||||
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
|
||||
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
|
||||
]
|
||||
|
||||
@@ -7,115 +7,84 @@ from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
class Cast(Func):
|
||||
"""Coerce an expression to a new field type."""
|
||||
|
||||
function = "CAST"
|
||||
template = "%(function)s(%(expressions)s AS %(db_type)s)"
|
||||
function = 'CAST'
|
||||
template = '%(function)s(%(expressions)s AS %(db_type)s)'
|
||||
|
||||
def __init__(self, expression, output_field):
|
||||
super().__init__(expression, output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context["db_type"] = self.output_field.cast_db_type(connection)
|
||||
extra_context['db_type'] = self.output_field.cast_db_type(connection)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
db_type = self.output_field.db_type(connection)
|
||||
if db_type in {"datetime", "time"}:
|
||||
if db_type in {'datetime', 'time'}:
|
||||
# Use strftime as datetime/time don't keep fractional seconds.
|
||||
template = "strftime(%%s, %(expressions)s)"
|
||||
sql, params = super().as_sql(
|
||||
compiler, connection, template=template, **extra_context
|
||||
)
|
||||
format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
|
||||
template = 'strftime(%%s, %(expressions)s)'
|
||||
sql, params = super().as_sql(compiler, connection, template=template, **extra_context)
|
||||
format_string = '%H:%M:%f' if db_type == 'time' else '%Y-%m-%d %H:%M:%f'
|
||||
params.insert(0, format_string)
|
||||
return sql, params
|
||||
elif db_type == "date":
|
||||
template = "date(%(expressions)s)"
|
||||
return super().as_sql(
|
||||
compiler, connection, template=template, **extra_context
|
||||
)
|
||||
elif db_type == 'date':
|
||||
template = 'date(%(expressions)s)'
|
||||
return super().as_sql(compiler, connection, template=template, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
template = None
|
||||
output_type = self.output_field.get_internal_type()
|
||||
# MySQL doesn't support explicit cast to float.
|
||||
if output_type == "FloatField":
|
||||
template = "(%(expressions)s + 0.0)"
|
||||
if output_type == 'FloatField':
|
||||
template = '(%(expressions)s + 0.0)'
|
||||
# MariaDB doesn't support explicit cast to JSON.
|
||||
elif output_type == "JSONField" and connection.mysql_is_mariadb:
|
||||
elif output_type == 'JSONField' and connection.mysql_is_mariadb:
|
||||
template = "JSON_EXTRACT(%(expressions)s, '$')"
|
||||
return self.as_sql(compiler, connection, template=template, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# CAST would be valid too, but the :: shortcut syntax is more readable.
|
||||
# 'expressions' is wrapped in parentheses in case it's a complex
|
||||
# expression.
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="(%(expressions)s)::%(db_type)s",
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.output_field.get_internal_type() == "JSONField":
|
||||
if self.output_field.get_internal_type() == 'JSONField':
|
||||
# Oracle doesn't support explicit cast to JSON.
|
||||
template = "JSON_QUERY(%(expressions)s, '$')"
|
||||
return super().as_sql(
|
||||
compiler, connection, template=template, **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, template=template, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Coalesce(Func):
|
||||
"""Return, from left to right, the first non-null expression."""
|
||||
|
||||
function = "COALESCE"
|
||||
function = 'COALESCE'
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Coalesce must take at least two expressions")
|
||||
raise ValueError('Coalesce must take at least two expressions')
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def empty_result_set_value(self):
|
||||
for expression in self.get_source_expressions():
|
||||
result = expression.empty_result_set_value
|
||||
if result is NotImplemented or result is not None:
|
||||
return result
|
||||
return None
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
# Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
|
||||
# so convert all fields to NCLOB when that type is expected.
|
||||
if self.output_field.get_internal_type() == "TextField":
|
||||
if self.output_field.get_internal_type() == 'TextField':
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Func(expression, function="TO_NCLOB")
|
||||
for expression in self.get_source_expressions()
|
||||
]
|
||||
)
|
||||
clone.set_source_expressions([
|
||||
Func(expression, function='TO_NCLOB') for expression in self.get_source_expressions()
|
||||
])
|
||||
return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
|
||||
return self.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Collate(Func):
|
||||
function = "COLLATE"
|
||||
template = "%(expressions)s %(function)s %(collation)s"
|
||||
# Inspired from
|
||||
# https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||
collation_re = _lazy_re_compile(r"^[\w\-]+$")
|
||||
function = 'COLLATE'
|
||||
template = '%(expressions)s %(function)s %(collation)s'
|
||||
# Inspired from https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||
collation_re = _lazy_re_compile(r'^[\w\-]+$')
|
||||
|
||||
def __init__(self, expression, collation):
|
||||
if not (collation and self.collation_re.match(collation)):
|
||||
raise ValueError("Invalid collation name: %r." % collation)
|
||||
raise ValueError('Invalid collation name: %r.' % collation)
|
||||
self.collation = collation
|
||||
super().__init__(expression)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
|
||||
extra_context.setdefault('collation', connection.ops.quote_name(self.collation))
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
@@ -127,21 +96,20 @@ class Greatest(Func):
|
||||
On PostgreSQL, the maximum not-null expression is returned.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
|
||||
"""
|
||||
|
||||
function = "GREATEST"
|
||||
function = 'GREATEST'
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Greatest must take at least two expressions")
|
||||
raise ValueError('Greatest must take at least two expressions')
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MAX function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function="MAX", **extra_context)
|
||||
return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
|
||||
|
||||
|
||||
class JSONObject(Func):
|
||||
function = "JSON_OBJECT"
|
||||
function = 'JSON_OBJECT'
|
||||
output_field = JSONField()
|
||||
|
||||
def __init__(self, **fields):
|
||||
@@ -153,7 +121,7 @@ class JSONObject(Func):
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.has_json_object_function:
|
||||
raise NotSupportedError(
|
||||
"JSONObject() is not supported on this database backend."
|
||||
'JSONObject() is not supported on this database backend.'
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
@@ -161,21 +129,21 @@ class JSONObject(Func):
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="JSONB_BUILD_OBJECT",
|
||||
function='JSONB_BUILD_OBJECT',
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
class ArgJoiner:
|
||||
def join(self, args):
|
||||
args = [" VALUE ".join(arg) for arg in zip(args[::2], args[1::2])]
|
||||
return ", ".join(args)
|
||||
args = [' VALUE '.join(arg) for arg in zip(args[::2], args[1::2])]
|
||||
return ', '.join(args)
|
||||
|
||||
return self.as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
arg_joiner=ArgJoiner(),
|
||||
template="%(function)s(%(expressions)s RETURNING CLOB)",
|
||||
template='%(function)s(%(expressions)s RETURNING CLOB)',
|
||||
**extra_context,
|
||||
)
|
||||
|
||||
@@ -188,25 +156,24 @@ class Least(Func):
|
||||
On PostgreSQL, return the minimum not-null expression.
|
||||
On MySQL, Oracle, and SQLite, if any expression is null, return null.
|
||||
"""
|
||||
|
||||
function = "LEAST"
|
||||
function = 'LEAST'
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Least must take at least two expressions")
|
||||
raise ValueError('Least must take at least two expressions')
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
"""Use the MIN function on SQLite."""
|
||||
return super().as_sqlite(compiler, connection, function="MIN", **extra_context)
|
||||
return super().as_sqlite(compiler, connection, function='MIN', **extra_context)
|
||||
|
||||
|
||||
class NullIf(Func):
|
||||
function = "NULLIF"
|
||||
function = 'NULLIF'
|
||||
arity = 2
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
expression1 = self.get_source_expressions()[0]
|
||||
if isinstance(expression1, Value) and expression1.value is None:
|
||||
raise ValueError("Oracle does not allow Value(None) for expression1.")
|
||||
raise ValueError('Oracle does not allow Value(None) for expression1.')
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
@@ -3,20 +3,10 @@ from datetime import datetime
|
||||
from django.conf import settings
|
||||
from django.db.models.expressions import Func
|
||||
from django.db.models.fields import (
|
||||
DateField,
|
||||
DateTimeField,
|
||||
DurationField,
|
||||
Field,
|
||||
IntegerField,
|
||||
TimeField,
|
||||
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
|
||||
)
|
||||
from django.db.models.lookups import (
|
||||
Transform,
|
||||
YearExact,
|
||||
YearGt,
|
||||
YearGte,
|
||||
YearLt,
|
||||
YearLte,
|
||||
Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
|
||||
)
|
||||
from django.utils import timezone
|
||||
|
||||
@@ -46,7 +36,7 @@ class Extract(TimezoneMixin, Transform):
|
||||
if self.lookup_name is None:
|
||||
self.lookup_name = lookup_name
|
||||
if self.lookup_name is None:
|
||||
raise ValueError("lookup_name must be provided")
|
||||
raise ValueError('lookup_name must be provided')
|
||||
self.tzinfo = tzinfo
|
||||
super().__init__(expression, **extra)
|
||||
|
||||
@@ -57,16 +47,14 @@ class Extract(TimezoneMixin, Transform):
|
||||
tzname = self.get_tzname()
|
||||
sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
raise ValueError('tzinfo can only be used with DateTimeField.')
|
||||
elif isinstance(lhs_output_field, DateField):
|
||||
sql = connection.ops.date_extract_sql(self.lookup_name, sql)
|
||||
elif isinstance(lhs_output_field, TimeField):
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
elif isinstance(lhs_output_field, DurationField):
|
||||
if not connection.features.has_native_duration_field:
|
||||
raise ValueError(
|
||||
"Extract requires native DurationField database support."
|
||||
)
|
||||
raise ValueError('Extract requires native DurationField database support.')
|
||||
sql = connection.ops.time_extract_sql(self.lookup_name, sql)
|
||||
else:
|
||||
# resolve_expression has already validated the output_field so this
|
||||
@@ -74,38 +62,22 @@ class Extract(TimezoneMixin, Transform):
|
||||
assert False, "Tried to Extract from an invalid type."
|
||||
return sql, params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
copy = super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
field = getattr(copy.lhs, "output_field", None)
|
||||
if field is None:
|
||||
return copy
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
field = copy.lhs.output_field
|
||||
if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
|
||||
raise ValueError(
|
||||
"Extract input expression must be DateField, DateTimeField, "
|
||||
"TimeField, or DurationField."
|
||||
'Extract input expression must be DateField, DateTimeField, '
|
||||
'TimeField, or DurationField.'
|
||||
)
|
||||
# Passing dates to functions expecting datetimes is most likely a mistake.
|
||||
if type(field) == DateField and copy.lookup_name in (
|
||||
"hour",
|
||||
"minute",
|
||||
"second",
|
||||
):
|
||||
if type(field) == DateField and copy.lookup_name in ('hour', 'minute', 'second'):
|
||||
raise ValueError(
|
||||
"Cannot extract time component '%s' from DateField '%s'."
|
||||
% (copy.lookup_name, field.name)
|
||||
"Cannot extract time component '%s' from DateField '%s'. " % (copy.lookup_name, field.name)
|
||||
)
|
||||
if isinstance(field, DurationField) and copy.lookup_name in (
|
||||
"year",
|
||||
"iso_year",
|
||||
"month",
|
||||
"week",
|
||||
"week_day",
|
||||
"iso_week_day",
|
||||
"quarter",
|
||||
if (
|
||||
isinstance(field, DurationField) and
|
||||
copy.lookup_name in ('year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter')
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot extract component '%s' from DurationField '%s'."
|
||||
@@ -115,21 +87,20 @@ class Extract(TimezoneMixin, Transform):
|
||||
|
||||
|
||||
class ExtractYear(Extract):
|
||||
lookup_name = "year"
|
||||
lookup_name = 'year'
|
||||
|
||||
|
||||
class ExtractIsoYear(Extract):
|
||||
"""Return the ISO-8601 week-numbering year."""
|
||||
|
||||
lookup_name = "iso_year"
|
||||
lookup_name = 'iso_year'
|
||||
|
||||
|
||||
class ExtractMonth(Extract):
|
||||
lookup_name = "month"
|
||||
lookup_name = 'month'
|
||||
|
||||
|
||||
class ExtractDay(Extract):
|
||||
lookup_name = "day"
|
||||
lookup_name = 'day'
|
||||
|
||||
|
||||
class ExtractWeek(Extract):
|
||||
@@ -137,8 +108,7 @@ class ExtractWeek(Extract):
|
||||
Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
|
||||
week.
|
||||
"""
|
||||
|
||||
lookup_name = "week"
|
||||
lookup_name = 'week'
|
||||
|
||||
|
||||
class ExtractWeekDay(Extract):
|
||||
@@ -147,30 +117,28 @@ class ExtractWeekDay(Extract):
|
||||
|
||||
To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
|
||||
"""
|
||||
|
||||
lookup_name = "week_day"
|
||||
lookup_name = 'week_day'
|
||||
|
||||
|
||||
class ExtractIsoWeekDay(Extract):
|
||||
"""Return Monday=1 through Sunday=7, based on ISO-8601."""
|
||||
|
||||
lookup_name = "iso_week_day"
|
||||
lookup_name = 'iso_week_day'
|
||||
|
||||
|
||||
class ExtractQuarter(Extract):
|
||||
lookup_name = "quarter"
|
||||
lookup_name = 'quarter'
|
||||
|
||||
|
||||
class ExtractHour(Extract):
|
||||
lookup_name = "hour"
|
||||
lookup_name = 'hour'
|
||||
|
||||
|
||||
class ExtractMinute(Extract):
|
||||
lookup_name = "minute"
|
||||
lookup_name = 'minute'
|
||||
|
||||
|
||||
class ExtractSecond(Extract):
|
||||
lookup_name = "second"
|
||||
lookup_name = 'second'
|
||||
|
||||
|
||||
DateField.register_lookup(ExtractYear)
|
||||
@@ -204,32 +172,21 @@ ExtractIsoYear.register_lookup(YearLte)
|
||||
|
||||
|
||||
class Now(Func):
|
||||
template = "CURRENT_TIMESTAMP"
|
||||
template = 'CURRENT_TIMESTAMP'
|
||||
output_field = DateTimeField()
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
|
||||
# transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
|
||||
# other databases.
|
||||
return self.as_sql(
|
||||
compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
|
||||
)
|
||||
return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
|
||||
|
||||
|
||||
class TruncBase(TimezoneMixin, Transform):
|
||||
kind = None
|
||||
tzinfo = None
|
||||
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(
|
||||
self,
|
||||
expression,
|
||||
output_field=None,
|
||||
tzinfo=None,
|
||||
is_dst=timezone.NOT_PASSED,
|
||||
**extra,
|
||||
):
|
||||
def __init__(self, expression, output_field=None, tzinfo=None, is_dst=None, **extra):
|
||||
self.tzinfo = tzinfo
|
||||
self.is_dst = is_dst
|
||||
super().__init__(expression, output_field=output_field, **extra)
|
||||
@@ -240,7 +197,7 @@ class TruncBase(TimezoneMixin, Transform):
|
||||
if isinstance(self.lhs.output_field, DateTimeField):
|
||||
tzname = self.get_tzname()
|
||||
elif self.tzinfo is not None:
|
||||
raise ValueError("tzinfo can only be used with DateTimeField.")
|
||||
raise ValueError('tzinfo can only be used with DateTimeField.')
|
||||
if isinstance(self.output_field, DateTimeField):
|
||||
sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
|
||||
elif isinstance(self.output_field, DateField):
|
||||
@@ -248,66 +205,36 @@ class TruncBase(TimezoneMixin, Transform):
|
||||
elif isinstance(self.output_field, TimeField):
|
||||
sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Trunc only valid on DateField, TimeField, or DateTimeField."
|
||||
)
|
||||
raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
|
||||
return sql, inner_params
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
copy = super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
field = copy.lhs.output_field
|
||||
# DateTimeField is a subclass of DateField so this works for both.
|
||||
if not isinstance(field, (DateField, TimeField)):
|
||||
raise TypeError(
|
||||
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
|
||||
)
|
||||
assert isinstance(field, (DateField, TimeField)), (
|
||||
"%r isn't a DateField, TimeField, or DateTimeField." % field.name
|
||||
)
|
||||
# If self.output_field was None, then accessing the field will trigger
|
||||
# the resolver to assign it to self.lhs.output_field.
|
||||
if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
|
||||
raise ValueError(
|
||||
"output_field must be either DateField, TimeField, or DateTimeField"
|
||||
)
|
||||
raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
|
||||
# Passing dates or times to functions expecting datetimes is most
|
||||
# likely a mistake.
|
||||
class_output_field = (
|
||||
self.__class__.output_field
|
||||
if isinstance(self.__class__.output_field, Field)
|
||||
else None
|
||||
)
|
||||
class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
|
||||
output_field = class_output_field or copy.output_field
|
||||
has_explicit_output_field = (
|
||||
class_output_field or field.__class__ is not copy.output_field.__class__
|
||||
)
|
||||
has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
|
||||
if type(field) == DateField and (
|
||||
isinstance(output_field, DateTimeField)
|
||||
or copy.kind in ("hour", "minute", "second", "time")
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot truncate DateField '%s' to %s."
|
||||
% (
|
||||
field.name,
|
||||
output_field.__class__.__name__
|
||||
if has_explicit_output_field
|
||||
else "DateTimeField",
|
||||
)
|
||||
)
|
||||
isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
|
||||
raise ValueError("Cannot truncate DateField '%s' to %s. " % (
|
||||
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||
))
|
||||
elif isinstance(field, TimeField) and (
|
||||
isinstance(output_field, DateTimeField)
|
||||
or copy.kind in ("year", "quarter", "month", "week", "day", "date")
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot truncate TimeField '%s' to %s."
|
||||
% (
|
||||
field.name,
|
||||
output_field.__class__.__name__
|
||||
if has_explicit_output_field
|
||||
else "DateTimeField",
|
||||
)
|
||||
)
|
||||
isinstance(output_field, DateTimeField) or
|
||||
copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
|
||||
raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
|
||||
field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
|
||||
))
|
||||
return copy
|
||||
|
||||
def convert_value(self, value, expression, connection):
|
||||
@@ -319,8 +246,8 @@ class TruncBase(TimezoneMixin, Transform):
|
||||
value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
|
||||
elif not connection.features.has_zoneinfo_database:
|
||||
raise ValueError(
|
||||
"Database returned an invalid datetime value. Are time "
|
||||
"zone definitions for your database installed?"
|
||||
'Database returned an invalid datetime value. Are time '
|
||||
'zone definitions for your database installed?'
|
||||
)
|
||||
elif isinstance(value, datetime):
|
||||
if value is None:
|
||||
@@ -334,48 +261,38 @@ class TruncBase(TimezoneMixin, Transform):
|
||||
|
||||
class Trunc(TruncBase):
|
||||
|
||||
# RemovedInDjango50Warning: when the deprecation ends, remove is_dst
|
||||
# argument.
|
||||
def __init__(
|
||||
self,
|
||||
expression,
|
||||
kind,
|
||||
output_field=None,
|
||||
tzinfo=None,
|
||||
is_dst=timezone.NOT_PASSED,
|
||||
**extra,
|
||||
):
|
||||
def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=None, **extra):
|
||||
self.kind = kind
|
||||
super().__init__(
|
||||
expression, output_field=output_field, tzinfo=tzinfo, is_dst=is_dst, **extra
|
||||
expression, output_field=output_field, tzinfo=tzinfo,
|
||||
is_dst=is_dst, **extra
|
||||
)
|
||||
|
||||
|
||||
class TruncYear(TruncBase):
|
||||
kind = "year"
|
||||
kind = 'year'
|
||||
|
||||
|
||||
class TruncQuarter(TruncBase):
|
||||
kind = "quarter"
|
||||
kind = 'quarter'
|
||||
|
||||
|
||||
class TruncMonth(TruncBase):
|
||||
kind = "month"
|
||||
kind = 'month'
|
||||
|
||||
|
||||
class TruncWeek(TruncBase):
|
||||
"""Truncate to midnight on the Monday of the week."""
|
||||
|
||||
kind = "week"
|
||||
kind = 'week'
|
||||
|
||||
|
||||
class TruncDay(TruncBase):
|
||||
kind = "day"
|
||||
kind = 'day'
|
||||
|
||||
|
||||
class TruncDate(TruncBase):
|
||||
kind = "date"
|
||||
lookup_name = "date"
|
||||
kind = 'date'
|
||||
lookup_name = 'date'
|
||||
output_field = DateField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
@@ -387,8 +304,8 @@ class TruncDate(TruncBase):
|
||||
|
||||
|
||||
class TruncTime(TruncBase):
|
||||
kind = "time"
|
||||
lookup_name = "time"
|
||||
kind = 'time'
|
||||
lookup_name = 'time'
|
||||
output_field = TimeField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
@@ -400,15 +317,15 @@ class TruncTime(TruncBase):
|
||||
|
||||
|
||||
class TruncHour(TruncBase):
|
||||
kind = "hour"
|
||||
kind = 'hour'
|
||||
|
||||
|
||||
class TruncMinute(TruncBase):
|
||||
kind = "minute"
|
||||
kind = 'minute'
|
||||
|
||||
|
||||
class TruncSecond(TruncBase):
|
||||
kind = "second"
|
||||
kind = 'second'
|
||||
|
||||
|
||||
DateTimeField.register_lookup(TruncDate)
|
||||
|
||||
@@ -1,43 +1,40 @@
|
||||
import math
|
||||
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.expressions import Func
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
from django.db.models.functions import Cast
|
||||
from django.db.models.functions.mixins import (
|
||||
FixDecimalInputMixin,
|
||||
NumericOutputFieldMixin,
|
||||
FixDecimalInputMixin, NumericOutputFieldMixin,
|
||||
)
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
|
||||
class Abs(Transform):
|
||||
function = "ABS"
|
||||
lookup_name = "abs"
|
||||
function = 'ABS'
|
||||
lookup_name = 'abs'
|
||||
|
||||
|
||||
class ACos(NumericOutputFieldMixin, Transform):
|
||||
function = "ACOS"
|
||||
lookup_name = "acos"
|
||||
function = 'ACOS'
|
||||
lookup_name = 'acos'
|
||||
|
||||
|
||||
class ASin(NumericOutputFieldMixin, Transform):
|
||||
function = "ASIN"
|
||||
lookup_name = "asin"
|
||||
function = 'ASIN'
|
||||
lookup_name = 'asin'
|
||||
|
||||
|
||||
class ATan(NumericOutputFieldMixin, Transform):
|
||||
function = "ATAN"
|
||||
lookup_name = "atan"
|
||||
function = 'ATAN'
|
||||
lookup_name = 'atan'
|
||||
|
||||
|
||||
class ATan2(NumericOutputFieldMixin, Func):
|
||||
function = "ATAN2"
|
||||
function = 'ATAN2'
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(
|
||||
connection.ops, "spatialite", False
|
||||
) or connection.ops.spatial_version >= (5, 0, 0):
|
||||
if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version >= (5, 0, 0):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually ATan2(y, x), returning the inverse tangent
|
||||
# of y / x, but it's ATan2(x, y) on SpatiaLite < 5.0.0.
|
||||
@@ -45,74 +42,67 @@ class ATan2(NumericOutputFieldMixin, Func):
|
||||
# arguments are mixed between integer and float or decimal.
|
||||
# https://www.gaia-gis.it/fossil/libspatialite/tktview?name=0f72cca3a2
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Cast(expression, FloatField())
|
||||
if isinstance(expression.output_field, IntegerField)
|
||||
else expression
|
||||
for expression in self.get_source_expressions()[::-1]
|
||||
]
|
||||
)
|
||||
clone.set_source_expressions([
|
||||
Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
|
||||
else expression for expression in self.get_source_expressions()[::-1]
|
||||
])
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Ceil(Transform):
|
||||
function = "CEILING"
|
||||
lookup_name = "ceil"
|
||||
function = 'CEILING'
|
||||
lookup_name = 'ceil'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="CEIL", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
|
||||
|
||||
|
||||
class Cos(NumericOutputFieldMixin, Transform):
|
||||
function = "COS"
|
||||
lookup_name = "cos"
|
||||
function = 'COS'
|
||||
lookup_name = 'cos'
|
||||
|
||||
|
||||
class Cot(NumericOutputFieldMixin, Transform):
|
||||
function = "COT"
|
||||
lookup_name = "cot"
|
||||
function = 'COT'
|
||||
lookup_name = 'cot'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, template="(1 / TAN(%(expressions)s))", **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
|
||||
|
||||
|
||||
class Degrees(NumericOutputFieldMixin, Transform):
|
||||
function = "DEGREES"
|
||||
lookup_name = "degrees"
|
||||
function = 'DEGREES'
|
||||
lookup_name = 'degrees'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="((%%(expressions)s) * 180 / %s)" % math.pi,
|
||||
**extra_context,
|
||||
compiler, connection,
|
||||
template='((%%(expressions)s) * 180 / %s)' % math.pi,
|
||||
**extra_context
|
||||
)
|
||||
|
||||
|
||||
class Exp(NumericOutputFieldMixin, Transform):
|
||||
function = "EXP"
|
||||
lookup_name = "exp"
|
||||
function = 'EXP'
|
||||
lookup_name = 'exp'
|
||||
|
||||
|
||||
class Floor(Transform):
|
||||
function = "FLOOR"
|
||||
lookup_name = "floor"
|
||||
function = 'FLOOR'
|
||||
lookup_name = 'floor'
|
||||
|
||||
|
||||
class Ln(NumericOutputFieldMixin, Transform):
|
||||
function = "LN"
|
||||
lookup_name = "ln"
|
||||
function = 'LN'
|
||||
lookup_name = 'ln'
|
||||
|
||||
|
||||
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = "LOG"
|
||||
function = 'LOG'
|
||||
arity = 2
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if not getattr(connection.ops, "spatialite", False):
|
||||
if not getattr(connection.ops, 'spatialite', False):
|
||||
return self.as_sql(compiler, connection)
|
||||
# This function is usually Log(b, x) returning the logarithm of x to
|
||||
# the base b, but on SpatiaLite it's Log(x, b).
|
||||
@@ -122,91 +112,72 @@ class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
|
||||
|
||||
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||
function = "MOD"
|
||||
function = 'MOD'
|
||||
arity = 2
|
||||
|
||||
|
||||
class Pi(NumericOutputFieldMixin, Func):
|
||||
function = "PI"
|
||||
function = 'PI'
|
||||
arity = 0
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, template=str(math.pi), **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
|
||||
|
||||
|
||||
class Power(NumericOutputFieldMixin, Func):
|
||||
function = "POWER"
|
||||
function = 'POWER'
|
||||
arity = 2
|
||||
|
||||
|
||||
class Radians(NumericOutputFieldMixin, Transform):
|
||||
function = "RADIANS"
|
||||
lookup_name = "radians"
|
||||
function = 'RADIANS'
|
||||
lookup_name = 'radians'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="((%%(expressions)s) * %s / 180)" % math.pi,
|
||||
**extra_context,
|
||||
compiler, connection,
|
||||
template='((%%(expressions)s) * %s / 180)' % math.pi,
|
||||
**extra_context
|
||||
)
|
||||
|
||||
|
||||
class Random(NumericOutputFieldMixin, Func):
|
||||
function = "RANDOM"
|
||||
function = 'RANDOM'
|
||||
arity = 0
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='RAND', **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, function="DBMS_RANDOM.VALUE", **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, function='DBMS_RANDOM.VALUE', **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="RAND", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='RAND', **extra_context)
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
return []
|
||||
|
||||
|
||||
class Round(FixDecimalInputMixin, Transform):
|
||||
function = "ROUND"
|
||||
lookup_name = "round"
|
||||
arity = None # Override Transform's arity=1 to enable passing precision.
|
||||
|
||||
def __init__(self, expression, precision=0, **extra):
|
||||
super().__init__(expression, precision, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
precision = self.get_source_expressions()[1]
|
||||
if isinstance(precision, Value) and precision.value < 0:
|
||||
raise ValueError("SQLite does not support negative precision.")
|
||||
return super().as_sqlite(compiler, connection, **extra_context)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
source = self.get_source_expressions()[0]
|
||||
return source.output_field
|
||||
class Round(Transform):
|
||||
function = 'ROUND'
|
||||
lookup_name = 'round'
|
||||
|
||||
|
||||
class Sign(Transform):
|
||||
function = "SIGN"
|
||||
lookup_name = "sign"
|
||||
function = 'SIGN'
|
||||
lookup_name = 'sign'
|
||||
|
||||
|
||||
class Sin(NumericOutputFieldMixin, Transform):
|
||||
function = "SIN"
|
||||
lookup_name = "sin"
|
||||
function = 'SIN'
|
||||
lookup_name = 'sin'
|
||||
|
||||
|
||||
class Sqrt(NumericOutputFieldMixin, Transform):
|
||||
function = "SQRT"
|
||||
lookup_name = "sqrt"
|
||||
function = 'SQRT'
|
||||
lookup_name = 'sqrt'
|
||||
|
||||
|
||||
class Tan(NumericOutputFieldMixin, Transform):
|
||||
function = "TAN"
|
||||
lookup_name = "tan"
|
||||
function = 'TAN'
|
||||
lookup_name = 'tan'
|
||||
|
||||
@@ -5,6 +5,7 @@ from django.db.models.functions import Cast
|
||||
|
||||
|
||||
class FixDecimalInputMixin:
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
||||
# following function signatures:
|
||||
@@ -12,42 +13,36 @@ class FixDecimalInputMixin:
|
||||
# - MOD(double, double)
|
||||
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
|
||||
clone = self.copy()
|
||||
clone.set_source_expressions(
|
||||
[
|
||||
Cast(expression, output_field)
|
||||
if isinstance(expression.output_field, FloatField)
|
||||
else expression
|
||||
for expression in self.get_source_expressions()
|
||||
]
|
||||
)
|
||||
clone.set_source_expressions([
|
||||
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
|
||||
else expression for expression in self.get_source_expressions()
|
||||
])
|
||||
return clone.as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class FixDurationInputMixin:
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
sql, params = super().as_sql(compiler, connection, **extra_context)
|
||||
if self.output_field.get_internal_type() == "DurationField":
|
||||
sql = "CAST(%s AS SIGNED)" % sql
|
||||
if self.output_field.get_internal_type() == 'DurationField':
|
||||
sql = 'CAST(%s AS SIGNED)' % sql
|
||||
return sql, params
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
if self.output_field.get_internal_type() == "DurationField":
|
||||
if self.output_field.get_internal_type() == 'DurationField':
|
||||
expression = self.get_source_expressions()[0]
|
||||
options = self._get_repr_options()
|
||||
from django.db.backends.oracle.functions import (
|
||||
IntervalToSeconds,
|
||||
SecondsToInterval,
|
||||
IntervalToSeconds, SecondsToInterval,
|
||||
)
|
||||
|
||||
return compiler.compile(
|
||||
SecondsToInterval(
|
||||
self.__class__(IntervalToSeconds(expression), **options)
|
||||
)
|
||||
SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
|
||||
)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class NumericOutputFieldMixin:
|
||||
|
||||
def _resolve_output_field(self):
|
||||
source_fields = self.get_source_fields()
|
||||
if any(isinstance(s, DecimalField) for s in source_fields):
|
||||
|
||||
@@ -10,7 +10,7 @@ class MySQLSHA2Mixin:
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="SHA2(%%(expressions)s, %s)" % self.function[3:],
|
||||
template='SHA2(%%(expressions)s, %s)' % self.function[3:],
|
||||
**extra_content,
|
||||
)
|
||||
|
||||
@@ -40,28 +40,25 @@ class PostgreSQLSHAMixin:
|
||||
|
||||
|
||||
class Chr(Transform):
|
||||
function = "CHR"
|
||||
lookup_name = "chr"
|
||||
function = 'CHR'
|
||||
lookup_name = 'chr'
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="CHAR",
|
||||
template="%(function)s(%(expressions)s USING utf16)",
|
||||
**extra_context,
|
||||
compiler, connection, function='CHAR',
|
||||
template='%(function)s(%(expressions)s USING utf16)',
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="%(function)s(%(expressions)s USING NCHAR_CS)",
|
||||
**extra_context,
|
||||
compiler, connection,
|
||||
template='%(function)s(%(expressions)s USING NCHAR_CS)',
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="CHAR", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='CHAR', **extra_context)
|
||||
|
||||
|
||||
class ConcatPair(Func):
|
||||
@@ -69,38 +66,29 @@ class ConcatPair(Func):
|
||||
Concatenate two arguments together. This is used by `Concat` because not
|
||||
all backend databases support more than two arguments.
|
||||
"""
|
||||
|
||||
function = "CONCAT"
|
||||
function = 'CONCAT'
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
coalesced = self.coalesce()
|
||||
return super(ConcatPair, coalesced).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template="%(expressions)s",
|
||||
arg_joiner=" || ",
|
||||
**extra_context,
|
||||
compiler, connection, template='%(expressions)s', arg_joiner=' || ',
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
# Use CONCAT_WS with an empty separator so that NULLs are ignored.
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
function="CONCAT_WS",
|
||||
compiler, connection, function='CONCAT_WS',
|
||||
template="%(function)s('', %(expressions)s)",
|
||||
**extra_context,
|
||||
**extra_context
|
||||
)
|
||||
|
||||
def coalesce(self):
|
||||
# null on either side results in null for expression, wrap with coalesce
|
||||
c = self.copy()
|
||||
c.set_source_expressions(
|
||||
[
|
||||
Coalesce(expression, Value(""))
|
||||
for expression in c.get_source_expressions()
|
||||
]
|
||||
)
|
||||
c.set_source_expressions([
|
||||
Coalesce(expression, Value('')) for expression in c.get_source_expressions()
|
||||
])
|
||||
return c
|
||||
|
||||
|
||||
@@ -110,13 +98,12 @@ class Concat(Func):
|
||||
null expression when any arguments are null will wrap each argument in
|
||||
coalesce functions to ensure a non-null result.
|
||||
"""
|
||||
|
||||
function = None
|
||||
template = "%(expressions)s"
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if len(expressions) < 2:
|
||||
raise ValueError("Concat must take at least two expressions")
|
||||
raise ValueError('Concat must take at least two expressions')
|
||||
paired = self._paired(expressions)
|
||||
super().__init__(paired, **extra)
|
||||
|
||||
@@ -130,7 +117,7 @@ class Concat(Func):
|
||||
|
||||
|
||||
class Left(Func):
|
||||
function = "LEFT"
|
||||
function = 'LEFT'
|
||||
arity = 2
|
||||
output_field = CharField()
|
||||
|
||||
@@ -139,7 +126,7 @@ class Left(Func):
|
||||
expression: the name of a field, or an expression returning a string
|
||||
length: the number of characters to return from the start of the string
|
||||
"""
|
||||
if not hasattr(length, "resolve_expression"):
|
||||
if not hasattr(length, 'resolve_expression'):
|
||||
if length < 1:
|
||||
raise ValueError("'length' must be greater than 0.")
|
||||
super().__init__(expression, length, **extra)
|
||||
@@ -156,68 +143,57 @@ class Left(Func):
|
||||
|
||||
class Length(Transform):
|
||||
"""Return the number of characters in the expression."""
|
||||
|
||||
function = "LENGTH"
|
||||
lookup_name = "length"
|
||||
function = 'LENGTH'
|
||||
lookup_name = 'length'
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(
|
||||
compiler, connection, function="CHAR_LENGTH", **extra_context
|
||||
)
|
||||
return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
|
||||
|
||||
|
||||
class Lower(Transform):
|
||||
function = "LOWER"
|
||||
lookup_name = "lower"
|
||||
function = 'LOWER'
|
||||
lookup_name = 'lower'
|
||||
|
||||
|
||||
class LPad(Func):
|
||||
function = "LPAD"
|
||||
function = 'LPAD'
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, length, fill_text=Value(" "), **extra):
|
||||
if (
|
||||
not hasattr(length, "resolve_expression")
|
||||
and length is not None
|
||||
and length < 0
|
||||
):
|
||||
def __init__(self, expression, length, fill_text=Value(' '), **extra):
|
||||
if not hasattr(length, 'resolve_expression') and length is not None and length < 0:
|
||||
raise ValueError("'length' must be greater or equal to 0.")
|
||||
super().__init__(expression, length, fill_text, **extra)
|
||||
|
||||
|
||||
class LTrim(Transform):
|
||||
function = "LTRIM"
|
||||
lookup_name = "ltrim"
|
||||
function = 'LTRIM'
|
||||
lookup_name = 'ltrim'
|
||||
|
||||
|
||||
class MD5(OracleHashMixin, Transform):
|
||||
function = "MD5"
|
||||
lookup_name = "md5"
|
||||
function = 'MD5'
|
||||
lookup_name = 'md5'
|
||||
|
||||
|
||||
class Ord(Transform):
|
||||
function = "ASCII"
|
||||
lookup_name = "ord"
|
||||
function = 'ASCII'
|
||||
lookup_name = 'ord'
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_mysql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="ORD", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='ORD', **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="UNICODE", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='UNICODE', **extra_context)
|
||||
|
||||
|
||||
class Repeat(Func):
|
||||
function = "REPEAT"
|
||||
function = 'REPEAT'
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, number, **extra):
|
||||
if (
|
||||
not hasattr(number, "resolve_expression")
|
||||
and number is not None
|
||||
and number < 0
|
||||
):
|
||||
if not hasattr(number, 'resolve_expression') and number is not None and number < 0:
|
||||
raise ValueError("'number' must be greater or equal to 0.")
|
||||
super().__init__(expression, number, **extra)
|
||||
|
||||
@@ -229,76 +205,73 @@ class Repeat(Func):
|
||||
|
||||
|
||||
class Replace(Func):
|
||||
function = "REPLACE"
|
||||
function = 'REPLACE'
|
||||
|
||||
def __init__(self, expression, text, replacement=Value(""), **extra):
|
||||
def __init__(self, expression, text, replacement=Value(''), **extra):
|
||||
super().__init__(expression, text, replacement, **extra)
|
||||
|
||||
|
||||
class Reverse(Transform):
|
||||
function = "REVERSE"
|
||||
lookup_name = "reverse"
|
||||
function = 'REVERSE'
|
||||
lookup_name = 'reverse'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
# REVERSE in Oracle is undocumented and doesn't support multi-byte
|
||||
# strings. Use a special subquery instead.
|
||||
return super().as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
compiler, connection,
|
||||
template=(
|
||||
"(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM "
|
||||
"(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s "
|
||||
"FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) "
|
||||
"GROUP BY %(expressions)s)"
|
||||
'(SELECT LISTAGG(s) WITHIN GROUP (ORDER BY n DESC) FROM '
|
||||
'(SELECT LEVEL n, SUBSTR(%(expressions)s, LEVEL, 1) s '
|
||||
'FROM DUAL CONNECT BY LEVEL <= LENGTH(%(expressions)s)) '
|
||||
'GROUP BY %(expressions)s)'
|
||||
),
|
||||
**extra_context,
|
||||
**extra_context
|
||||
)
|
||||
|
||||
|
||||
class Right(Left):
|
||||
function = "RIGHT"
|
||||
function = 'RIGHT'
|
||||
|
||||
def get_substr(self):
|
||||
return Substr(
|
||||
self.source_expressions[0], self.source_expressions[1] * Value(-1)
|
||||
)
|
||||
return Substr(self.source_expressions[0], self.source_expressions[1] * Value(-1))
|
||||
|
||||
|
||||
class RPad(LPad):
|
||||
function = "RPAD"
|
||||
function = 'RPAD'
|
||||
|
||||
|
||||
class RTrim(Transform):
|
||||
function = "RTRIM"
|
||||
lookup_name = "rtrim"
|
||||
function = 'RTRIM'
|
||||
lookup_name = 'rtrim'
|
||||
|
||||
|
||||
class SHA1(OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA1"
|
||||
lookup_name = "sha1"
|
||||
function = 'SHA1'
|
||||
lookup_name = 'sha1'
|
||||
|
||||
|
||||
class SHA224(MySQLSHA2Mixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA224"
|
||||
lookup_name = "sha224"
|
||||
function = 'SHA224'
|
||||
lookup_name = 'sha224'
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
raise NotSupportedError("SHA224 is not supported on Oracle.")
|
||||
raise NotSupportedError('SHA224 is not supported on Oracle.')
|
||||
|
||||
|
||||
class SHA256(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA256"
|
||||
lookup_name = "sha256"
|
||||
function = 'SHA256'
|
||||
lookup_name = 'sha256'
|
||||
|
||||
|
||||
class SHA384(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA384"
|
||||
lookup_name = "sha384"
|
||||
function = 'SHA384'
|
||||
lookup_name = 'sha384'
|
||||
|
||||
|
||||
class SHA512(MySQLSHA2Mixin, OracleHashMixin, PostgreSQLSHAMixin, Transform):
|
||||
function = "SHA512"
|
||||
lookup_name = "sha512"
|
||||
function = 'SHA512'
|
||||
lookup_name = 'sha512'
|
||||
|
||||
|
||||
class StrIndex(Func):
|
||||
@@ -307,17 +280,16 @@ class StrIndex(Func):
|
||||
first occurrence of a substring inside another string, or 0 if the
|
||||
substring is not found.
|
||||
"""
|
||||
|
||||
function = "INSTR"
|
||||
function = 'INSTR'
|
||||
arity = 2
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_postgresql(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="STRPOS", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
|
||||
|
||||
|
||||
class Substr(Func):
|
||||
function = "SUBSTRING"
|
||||
function = 'SUBSTRING'
|
||||
output_field = CharField()
|
||||
|
||||
def __init__(self, expression, pos, length=None, **extra):
|
||||
@@ -326,7 +298,7 @@ class Substr(Func):
|
||||
pos: an integer > 0, or an expression returning an integer
|
||||
length: an optional number of characters to return
|
||||
"""
|
||||
if not hasattr(pos, "resolve_expression"):
|
||||
if not hasattr(pos, 'resolve_expression'):
|
||||
if pos < 1:
|
||||
raise ValueError("'pos' must be greater than 0")
|
||||
expressions = [expression, pos]
|
||||
@@ -335,17 +307,17 @@ class Substr(Func):
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
|
||||
|
||||
def as_oracle(self, compiler, connection, **extra_context):
|
||||
return super().as_sql(compiler, connection, function="SUBSTR", **extra_context)
|
||||
return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
|
||||
|
||||
|
||||
class Trim(Transform):
|
||||
function = "TRIM"
|
||||
lookup_name = "trim"
|
||||
function = 'TRIM'
|
||||
lookup_name = 'trim'
|
||||
|
||||
|
||||
class Upper(Transform):
|
||||
function = "UPPER"
|
||||
lookup_name = "upper"
|
||||
function = 'UPPER'
|
||||
lookup_name = 'upper'
|
||||
|
||||
@@ -2,35 +2,26 @@ from django.db.models.expressions import Func
|
||||
from django.db.models.fields import FloatField, IntegerField
|
||||
|
||||
__all__ = [
|
||||
"CumeDist",
|
||||
"DenseRank",
|
||||
"FirstValue",
|
||||
"Lag",
|
||||
"LastValue",
|
||||
"Lead",
|
||||
"NthValue",
|
||||
"Ntile",
|
||||
"PercentRank",
|
||||
"Rank",
|
||||
"RowNumber",
|
||||
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
|
||||
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
|
||||
]
|
||||
|
||||
|
||||
class CumeDist(Func):
|
||||
function = "CUME_DIST"
|
||||
function = 'CUME_DIST'
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class DenseRank(Func):
|
||||
function = "DENSE_RANK"
|
||||
function = 'DENSE_RANK'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class FirstValue(Func):
|
||||
arity = 1
|
||||
function = "FIRST_VALUE"
|
||||
function = 'FIRST_VALUE'
|
||||
window_compatible = True
|
||||
|
||||
|
||||
@@ -40,12 +31,13 @@ class LagLeadFunction(Func):
|
||||
def __init__(self, expression, offset=1, default=None, **extra):
|
||||
if expression is None:
|
||||
raise ValueError(
|
||||
"%s requires a non-null source expression." % self.__class__.__name__
|
||||
'%s requires a non-null source expression.' %
|
||||
self.__class__.__name__
|
||||
)
|
||||
if offset is None or offset <= 0:
|
||||
raise ValueError(
|
||||
"%s requires a positive integer for the offset."
|
||||
% self.__class__.__name__
|
||||
'%s requires a positive integer for the offset.' %
|
||||
self.__class__.__name__
|
||||
)
|
||||
args = (expression, offset)
|
||||
if default is not None:
|
||||
@@ -58,32 +50,28 @@ class LagLeadFunction(Func):
|
||||
|
||||
|
||||
class Lag(LagLeadFunction):
|
||||
function = "LAG"
|
||||
function = 'LAG'
|
||||
|
||||
|
||||
class LastValue(Func):
|
||||
arity = 1
|
||||
function = "LAST_VALUE"
|
||||
function = 'LAST_VALUE'
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Lead(LagLeadFunction):
|
||||
function = "LEAD"
|
||||
function = 'LEAD'
|
||||
|
||||
|
||||
class NthValue(Func):
|
||||
function = "NTH_VALUE"
|
||||
function = 'NTH_VALUE'
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, expression, nth=1, **extra):
|
||||
if expression is None:
|
||||
raise ValueError(
|
||||
"%s requires a non-null source expression." % self.__class__.__name__
|
||||
)
|
||||
raise ValueError('%s requires a non-null source expression.' % self.__class__.__name__)
|
||||
if nth is None or nth <= 0:
|
||||
raise ValueError(
|
||||
"%s requires a positive integer as for nth." % self.__class__.__name__
|
||||
)
|
||||
raise ValueError('%s requires a positive integer as for nth.' % self.__class__.__name__)
|
||||
super().__init__(expression, nth, **extra)
|
||||
|
||||
def _resolve_output_field(self):
|
||||
@@ -92,29 +80,29 @@ class NthValue(Func):
|
||||
|
||||
|
||||
class Ntile(Func):
|
||||
function = "NTILE"
|
||||
function = 'NTILE'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
def __init__(self, num_buckets=1, **extra):
|
||||
if num_buckets <= 0:
|
||||
raise ValueError("num_buckets must be greater than 0.")
|
||||
raise ValueError('num_buckets must be greater than 0.')
|
||||
super().__init__(num_buckets, **extra)
|
||||
|
||||
|
||||
class PercentRank(Func):
|
||||
function = "PERCENT_RANK"
|
||||
function = 'PERCENT_RANK'
|
||||
output_field = FloatField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class Rank(Func):
|
||||
function = "RANK"
|
||||
function = 'RANK'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
|
||||
class RowNumber(Func):
|
||||
function = "ROW_NUMBER"
|
||||
function = 'ROW_NUMBER'
|
||||
output_field = IntegerField()
|
||||
window_compatible = True
|
||||
|
||||
@@ -5,11 +5,11 @@ from django.db.models.query_utils import Q
|
||||
from django.db.models.sql import Query
|
||||
from django.utils.functional import partition
|
||||
|
||||
__all__ = ["Index"]
|
||||
__all__ = ['Index']
|
||||
|
||||
|
||||
class Index:
|
||||
suffix = "idx"
|
||||
suffix = 'idx'
|
||||
# The max length of the name of the index (restricted to 30 for
|
||||
# cross-database compatibility with Oracle)
|
||||
max_name_length = 30
|
||||
@@ -25,48 +25,46 @@ class Index:
|
||||
include=None,
|
||||
):
|
||||
if opclasses and not name:
|
||||
raise ValueError("An index must be named to use opclasses.")
|
||||
raise ValueError('An index must be named to use opclasses.')
|
||||
if not isinstance(condition, (type(None), Q)):
|
||||
raise ValueError("Index.condition must be a Q instance.")
|
||||
raise ValueError('Index.condition must be a Q instance.')
|
||||
if condition and not name:
|
||||
raise ValueError("An index must be named to use condition.")
|
||||
raise ValueError('An index must be named to use condition.')
|
||||
if not isinstance(fields, (list, tuple)):
|
||||
raise ValueError("Index.fields must be a list or tuple.")
|
||||
raise ValueError('Index.fields must be a list or tuple.')
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise ValueError("Index.opclasses must be a list or tuple.")
|
||||
raise ValueError('Index.opclasses must be a list or tuple.')
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
"At least one field or expression is required to define an index."
|
||||
'At least one field or expression is required to define an '
|
||||
'index.'
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
"Index.fields and expressions are mutually exclusive.",
|
||||
'Index.fields and expressions are mutually exclusive.',
|
||||
)
|
||||
if expressions and not name:
|
||||
raise ValueError("An index must be named to use expressions.")
|
||||
raise ValueError('An index must be named to use expressions.')
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
"Index.opclasses cannot be used with expressions. Use "
|
||||
"django.contrib.postgres.indexes.OpClass() instead."
|
||||
'Index.opclasses cannot be used with expressions. Use '
|
||||
'django.contrib.postgres.indexes.OpClass() instead.'
|
||||
)
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
"Index.fields and Index.opclasses must have the same number of "
|
||||
"elements."
|
||||
)
|
||||
raise ValueError('Index.fields and Index.opclasses must have the same number of elements.')
|
||||
if fields and not all(isinstance(field, str) for field in fields):
|
||||
raise ValueError("Index.fields must contain only strings with field names.")
|
||||
raise ValueError('Index.fields must contain only strings with field names.')
|
||||
if include and not name:
|
||||
raise ValueError("A covering index must be named.")
|
||||
raise ValueError('A covering index must be named.')
|
||||
if not isinstance(include, (type(None), list, tuple)):
|
||||
raise ValueError("Index.include must be a list or tuple.")
|
||||
raise ValueError('Index.include must be a list or tuple.')
|
||||
self.fields = list(fields)
|
||||
# A list of 2-tuple with the field name and ordering ('' or 'DESC').
|
||||
self.fields_orders = [
|
||||
(field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
|
||||
(field_name[1:], 'DESC') if field_name.startswith('-') else (field_name, '')
|
||||
for field_name in self.fields
|
||||
]
|
||||
self.name = name or ""
|
||||
self.name = name or ''
|
||||
self.db_tablespace = db_tablespace
|
||||
self.opclasses = opclasses
|
||||
self.condition = condition
|
||||
@@ -89,10 +87,8 @@ class Index:
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def create_sql(self, model, schema_editor, using="", **kwargs):
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
def create_sql(self, model, schema_editor, using='', **kwargs):
|
||||
include = [model._meta.get_field(field_name).column for field_name in self.include]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
if self.expressions:
|
||||
index_expressions = []
|
||||
@@ -113,36 +109,29 @@ class Index:
|
||||
col_suffixes = [order[1] for order in self.fields_orders]
|
||||
expressions = None
|
||||
return schema_editor._create_index_sql(
|
||||
model,
|
||||
fields=fields,
|
||||
name=self.name,
|
||||
using=using,
|
||||
db_tablespace=self.db_tablespace,
|
||||
col_suffixes=col_suffixes,
|
||||
opclasses=self.opclasses,
|
||||
condition=condition,
|
||||
include=include,
|
||||
expressions=expressions,
|
||||
**kwargs,
|
||||
model, fields=fields, name=self.name, using=using,
|
||||
db_tablespace=self.db_tablespace, col_suffixes=col_suffixes,
|
||||
opclasses=self.opclasses, condition=condition, include=include,
|
||||
expressions=expressions, **kwargs,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor, **kwargs):
|
||||
return schema_editor._delete_index_sql(model, self.name, **kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace("django.db.models.indexes", "django.db.models")
|
||||
kwargs = {"name": self.name}
|
||||
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace('django.db.models.indexes', 'django.db.models')
|
||||
kwargs = {'name': self.name}
|
||||
if self.fields:
|
||||
kwargs["fields"] = self.fields
|
||||
kwargs['fields'] = self.fields
|
||||
if self.db_tablespace is not None:
|
||||
kwargs["db_tablespace"] = self.db_tablespace
|
||||
kwargs['db_tablespace'] = self.db_tablespace
|
||||
if self.opclasses:
|
||||
kwargs["opclasses"] = self.opclasses
|
||||
kwargs['opclasses'] = self.opclasses
|
||||
if self.condition:
|
||||
kwargs["condition"] = self.condition
|
||||
kwargs['condition'] = self.condition
|
||||
if self.include:
|
||||
kwargs["include"] = self.include
|
||||
kwargs['include'] = self.include
|
||||
return (path, self.expressions, kwargs)
|
||||
|
||||
def clone(self):
|
||||
@@ -159,44 +148,36 @@ class Index:
|
||||
fit its size by truncating the excess length.
|
||||
"""
|
||||
_, table_name = split_identifier(model._meta.db_table)
|
||||
column_names = [
|
||||
model._meta.get_field(field_name).column
|
||||
for field_name, order in self.fields_orders
|
||||
]
|
||||
column_names = [model._meta.get_field(field_name).column for field_name, order in self.fields_orders]
|
||||
column_names_with_order = [
|
||||
(("-%s" if order else "%s") % column_name)
|
||||
for column_name, (field_name, order) in zip(
|
||||
column_names, self.fields_orders
|
||||
)
|
||||
(('-%s' if order else '%s') % column_name)
|
||||
for column_name, (field_name, order) in zip(column_names, self.fields_orders)
|
||||
]
|
||||
# The length of the parts of the name is based on the default max
|
||||
# length of 30 characters.
|
||||
hash_data = [table_name] + column_names_with_order + [self.suffix]
|
||||
self.name = "%s_%s_%s" % (
|
||||
self.name = '%s_%s_%s' % (
|
||||
table_name[:11],
|
||||
column_names[0][:7],
|
||||
"%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
|
||||
'%s_%s' % (names_digest(*hash_data, length=6), self.suffix),
|
||||
)
|
||||
if len(self.name) > self.max_name_length:
|
||||
raise ValueError(
|
||||
"Index too long for multiple database support. Is self.suffix "
|
||||
"longer than 3 characters?"
|
||||
)
|
||||
if self.name[0] == "_" or self.name[0].isdigit():
|
||||
self.name = "D%s" % self.name[1:]
|
||||
assert len(self.name) <= self.max_name_length, (
|
||||
'Index too long for multiple database support. Is self.suffix '
|
||||
'longer than 3 characters?'
|
||||
)
|
||||
if self.name[0] == '_' or self.name[0].isdigit():
|
||||
self.name = 'D%s' % self.name[1:]
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:%s%s%s%s%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
"" if not self.fields else " fields=%s" % repr(self.fields),
|
||||
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
|
||||
"" if not self.name else " name=%s" % repr(self.name),
|
||||
""
|
||||
if self.db_tablespace is None
|
||||
else " db_tablespace=%s" % repr(self.db_tablespace),
|
||||
"" if self.condition is None else " condition=%s" % self.condition,
|
||||
"" if not self.include else " include=%s" % repr(self.include),
|
||||
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
|
||||
return '<%s:%s%s%s%s%s>' % (
|
||||
self.__class__.__name__,
|
||||
'' if not self.fields else " fields='%s'" % ', '.join(self.fields),
|
||||
'' if not self.expressions else " expressions='%s'" % ', '.join([
|
||||
str(expression) for expression in self.expressions
|
||||
]),
|
||||
'' if self.condition is None else ' condition=%s' % self.condition,
|
||||
'' if not self.include else " include='%s'" % ', '.join(self.include),
|
||||
'' if not self.opclasses else " opclasses='%s'" % ', '.join(self.opclasses),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
@@ -207,20 +188,17 @@ class Index:
|
||||
|
||||
class IndexExpression(Func):
|
||||
"""Order and wrap expressions for CREATE INDEX statements."""
|
||||
|
||||
template = "%(expressions)s"
|
||||
template = '%(expressions)s'
|
||||
wrapper_classes = (OrderBy, Collate)
|
||||
|
||||
def set_wrapper_classes(self, connection=None):
|
||||
# Some databases (e.g. MySQL) treats COLLATE as an indexed expression.
|
||||
if connection and connection.features.collate_as_index_expression:
|
||||
self.wrapper_classes = tuple(
|
||||
[
|
||||
wrapper_cls
|
||||
for wrapper_cls in self.wrapper_classes
|
||||
if wrapper_cls is not Collate
|
||||
]
|
||||
)
|
||||
self.wrapper_classes = tuple([
|
||||
wrapper_cls
|
||||
for wrapper_cls in self.wrapper_classes
|
||||
if wrapper_cls is not Collate
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def register_wrappers(cls, *wrapper_classes):
|
||||
@@ -244,17 +222,16 @@ class IndexExpression(Func):
|
||||
if len(wrapper_types) != len(set(wrapper_types)):
|
||||
raise ValueError(
|
||||
"Multiple references to %s can't be used in an indexed "
|
||||
"expression."
|
||||
% ", ".join(
|
||||
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
|
||||
)
|
||||
"expression." % ', '.join([
|
||||
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
|
||||
])
|
||||
)
|
||||
if expressions[1 : len(wrappers) + 1] != wrappers:
|
||||
if expressions[1:len(wrappers) + 1] != wrappers:
|
||||
raise ValueError(
|
||||
"%s must be topmost expressions in an indexed expression."
|
||||
% ", ".join(
|
||||
[wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes]
|
||||
)
|
||||
'%s must be topmost expressions in an indexed expression.'
|
||||
% ', '.join([
|
||||
wrapper_cls.__qualname__ for wrapper_cls in self.wrapper_classes
|
||||
])
|
||||
)
|
||||
# Wrap expressions in parentheses if they are not column references.
|
||||
root_expression = index_expressions[1]
|
||||
@@ -266,7 +243,7 @@ class IndexExpression(Func):
|
||||
for_save,
|
||||
)
|
||||
if not isinstance(resolve_root_expression, Col):
|
||||
root_expression = Func(root_expression, template="(%(expressions)s)")
|
||||
root_expression = Func(root_expression, template='(%(expressions)s)')
|
||||
|
||||
if wrappers:
|
||||
# Order wrappers and set their expressions.
|
||||
@@ -283,9 +260,7 @@ class IndexExpression(Func):
|
||||
else:
|
||||
# Use the root expression, if there are no wrappers.
|
||||
self.set_source_expressions([root_expression])
|
||||
return super().resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
return super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
# Casting to numeric is unnecessary.
|
||||
|
||||
@@ -1,23 +1,21 @@
|
||||
import itertools
|
||||
import math
|
||||
import warnings
|
||||
from copy import copy
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.db.models.expressions import Case, Expression, Func, Value, When
|
||||
from django.db.models.expressions import Case, Exists, Func, Value, When
|
||||
from django.db.models.fields import (
|
||||
BooleanField,
|
||||
CharField,
|
||||
DateTimeField,
|
||||
Field,
|
||||
IntegerField,
|
||||
UUIDField,
|
||||
CharField, DateTimeField, Field, IntegerField, UUIDField,
|
||||
)
|
||||
from django.db.models.query_utils import RegisterLookupMixin
|
||||
from django.utils.datastructures import OrderedSet
|
||||
from django.utils.deprecation import RemovedInDjango40Warning
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
|
||||
class Lookup(Expression):
|
||||
class Lookup:
|
||||
lookup_name = None
|
||||
prepare_rhs = True
|
||||
can_use_none_as_rhs = False
|
||||
@@ -25,20 +23,18 @@ class Lookup(Expression):
|
||||
def __init__(self, lhs, rhs):
|
||||
self.lhs, self.rhs = lhs, rhs
|
||||
self.rhs = self.get_prep_lookup()
|
||||
self.lhs = self.get_prep_lhs()
|
||||
if hasattr(self.lhs, "get_bilateral_transforms"):
|
||||
if hasattr(self.lhs, 'get_bilateral_transforms'):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
if bilateral_transforms:
|
||||
# Warn the user as soon as possible if they are trying to apply
|
||||
# a bilateral transformation on a nested QuerySet: that won't work.
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
from django.db.models.sql.query import ( # avoid circular import
|
||||
Query,
|
||||
)
|
||||
if isinstance(rhs, Query):
|
||||
raise NotImplementedError(
|
||||
"Bilateral transformations on nested querysets are not implemented."
|
||||
)
|
||||
raise NotImplementedError("Bilateral transformations on nested querysets are not implemented.")
|
||||
self.bilateral_transforms = bilateral_transforms
|
||||
|
||||
def apply_bilateral_transforms(self, value):
|
||||
@@ -46,9 +42,6 @@ class Lookup(Expression):
|
||||
value = transform(value)
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self.lhs!r}, {self.rhs!r})"
|
||||
|
||||
def batch_process_rhs(self, compiler, connection, rhs=None):
|
||||
if rhs is None:
|
||||
rhs = self.rhs
|
||||
@@ -63,7 +56,7 @@ class Lookup(Expression):
|
||||
sqls_params.extend(sql_params)
|
||||
else:
|
||||
_, params = self.get_db_prep_lookup(rhs, connection)
|
||||
sqls, sqls_params = ["%s"] * len(params), params
|
||||
sqls, sqls_params = ['%s'] * len(params), params
|
||||
return sqls, sqls_params
|
||||
|
||||
def get_source_expressions(self):
|
||||
@@ -78,32 +71,20 @@ class Lookup(Expression):
|
||||
self.lhs, self.rhs = new_exprs
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if not self.prepare_rhs or hasattr(self.rhs, "resolve_expression"):
|
||||
if hasattr(self.rhs, 'resolve_expression'):
|
||||
return self.rhs
|
||||
if hasattr(self.lhs, "output_field"):
|
||||
if hasattr(self.lhs.output_field, "get_prep_value"):
|
||||
return self.lhs.output_field.get_prep_value(self.rhs)
|
||||
elif self.rhs_is_direct_value():
|
||||
return Value(self.rhs)
|
||||
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
|
||||
return self.lhs.output_field.get_prep_value(self.rhs)
|
||||
return self.rhs
|
||||
|
||||
def get_prep_lhs(self):
|
||||
if hasattr(self.lhs, "resolve_expression"):
|
||||
return self.lhs
|
||||
return Value(self.lhs)
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
return ("%s", [value])
|
||||
return ('%s', [value])
|
||||
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs = lhs or self.lhs
|
||||
if hasattr(lhs, "resolve_expression"):
|
||||
if hasattr(lhs, 'resolve_expression'):
|
||||
lhs = lhs.resolve_expression(compiler.query)
|
||||
sql, params = compiler.compile(lhs)
|
||||
if isinstance(lhs, Lookup):
|
||||
# Wrapped in parentheses to respect operator precedence.
|
||||
sql = f"({sql})"
|
||||
return sql, params
|
||||
return compiler.compile(lhs)
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
value = self.rhs
|
||||
@@ -114,33 +95,37 @@ class Lookup(Expression):
|
||||
value = Value(value, output_field=self.lhs.output_field)
|
||||
value = self.apply_bilateral_transforms(value)
|
||||
value = value.resolve_expression(compiler.query)
|
||||
if hasattr(value, "as_sql"):
|
||||
sql, params = compiler.compile(value)
|
||||
# Ensure expression is wrapped in parentheses to respect operator
|
||||
# precedence but avoid double wrapping as it can be misinterpreted
|
||||
# on some backends (e.g. subqueries on SQLite).
|
||||
if sql and sql[0] != "(":
|
||||
sql = "(%s)" % sql
|
||||
return sql, params
|
||||
if hasattr(value, 'as_sql'):
|
||||
return compiler.compile(value)
|
||||
else:
|
||||
return self.get_db_prep_lookup(value, connection)
|
||||
|
||||
def rhs_is_direct_value(self):
|
||||
return not hasattr(self.rhs, "as_sql")
|
||||
return not hasattr(self.rhs, 'as_sql')
|
||||
|
||||
def relabeled_clone(self, relabels):
|
||||
new = copy(self)
|
||||
new.lhs = new.lhs.relabeled_clone(relabels)
|
||||
if hasattr(new.rhs, 'relabeled_clone'):
|
||||
new.rhs = new.rhs.relabeled_clone(relabels)
|
||||
return new
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
cols = []
|
||||
for source in self.get_source_expressions():
|
||||
cols.extend(source.get_group_by_cols())
|
||||
cols = self.lhs.get_group_by_cols()
|
||||
if hasattr(self.rhs, 'get_group_by_cols'):
|
||||
cols.extend(self.rhs.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
raise NotImplementedError
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
# Oracle doesn't allow EXISTS() and filters to be compared to another
|
||||
# expression unless they're wrapped in a CASE WHEN.
|
||||
# Oracle doesn't allow EXISTS() to be compared to another expression
|
||||
# unless it's wrapped in a CASE WHEN.
|
||||
wrapped = False
|
||||
exprs = []
|
||||
for expr in (self.lhs, self.rhs):
|
||||
if connection.ops.conditional_expression_supported_in_where_clause(expr):
|
||||
if isinstance(expr, Exists):
|
||||
expr = Case(When(expr, then=True), default=False)
|
||||
wrapped = True
|
||||
exprs.append(expr)
|
||||
@@ -148,8 +133,16 @@ class Lookup(Expression):
|
||||
return lookup.as_sql(compiler, connection)
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return BooleanField()
|
||||
def contains_aggregate(self):
|
||||
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
|
||||
|
||||
@cached_property
|
||||
def contains_over_clause(self):
|
||||
return self.lhs.contains_over_clause or getattr(self.rhs, 'contains_over_clause', False)
|
||||
|
||||
@property
|
||||
def is_summary(self):
|
||||
return self.lhs.is_summary or getattr(self.rhs, 'is_summary', False)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
@@ -163,34 +156,12 @@ class Lookup(Expression):
|
||||
def __hash__(self):
|
||||
return hash(make_hashable(self.identity))
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
c = self.copy()
|
||||
c.is_summary = summarize
|
||||
c.lhs = self.lhs.resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
c.rhs = self.rhs.resolve_expression(
|
||||
query, allow_joins, reuse, summarize, for_save
|
||||
)
|
||||
return c
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap filters with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||
# BY list.
|
||||
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
|
||||
class Transform(RegisterLookupMixin, Func):
|
||||
"""
|
||||
RegisterLookupMixin() is first so that get_lookup() and get_transform()
|
||||
first examine self and then check output_field.
|
||||
"""
|
||||
|
||||
bilateral = False
|
||||
arity = 1
|
||||
|
||||
@@ -199,7 +170,7 @@ class Transform(RegisterLookupMixin, Func):
|
||||
return self.get_source_expressions()[0]
|
||||
|
||||
def get_bilateral_transforms(self):
|
||||
if hasattr(self.lhs, "get_bilateral_transforms"):
|
||||
if hasattr(self.lhs, 'get_bilateral_transforms'):
|
||||
bilateral_transforms = self.lhs.get_bilateral_transforms()
|
||||
else:
|
||||
bilateral_transforms = []
|
||||
@@ -213,10 +184,9 @@ class BuiltinLookup(Lookup):
|
||||
lhs_sql, params = super().process_lhs(compiler, connection, lhs)
|
||||
field_internal_type = self.lhs.output_field.get_internal_type()
|
||||
db_type = self.lhs.output_field.db_type(connection=connection)
|
||||
lhs_sql = connection.ops.field_cast_sql(db_type, field_internal_type) % lhs_sql
|
||||
lhs_sql = (
|
||||
connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
|
||||
)
|
||||
lhs_sql = connection.ops.field_cast_sql(
|
||||
db_type, field_internal_type) % lhs_sql
|
||||
lhs_sql = connection.ops.lookup_cast(self.lookup_name, field_internal_type) % lhs_sql
|
||||
return lhs_sql, list(params)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
@@ -224,7 +194,7 @@ class BuiltinLookup(Lookup):
|
||||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
params.extend(rhs_params)
|
||||
rhs_sql = self.get_rhs_op(connection, rhs_sql)
|
||||
return "%s %s" % (lhs_sql, rhs_sql), params
|
||||
return '%s %s' % (lhs_sql, rhs_sql), params
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return connection.operators[self.lookup_name] % rhs
|
||||
@@ -235,22 +205,18 @@ class FieldGetDbPrepValueMixin:
|
||||
Some lookups require Field.get_db_prep_value() to be called on their
|
||||
inputs.
|
||||
"""
|
||||
|
||||
get_db_prep_lookup_value_is_iterable = False
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
# For relational fields, use the 'target_field' attribute of the
|
||||
# output_field.
|
||||
field = getattr(self.lhs.output_field, "target_field", None)
|
||||
get_db_prep_value = (
|
||||
getattr(field, "get_db_prep_value", None)
|
||||
or self.lhs.output_field.get_db_prep_value
|
||||
)
|
||||
field = getattr(self.lhs.output_field, 'target_field', None)
|
||||
get_db_prep_value = getattr(field, 'get_db_prep_value', None) or self.lhs.output_field.get_db_prep_value
|
||||
return (
|
||||
"%s",
|
||||
'%s',
|
||||
[get_db_prep_value(v, connection, prepared=True) for v in value]
|
||||
if self.get_db_prep_lookup_value_is_iterable
|
||||
else [get_db_prep_value(value, connection, prepared=True)],
|
||||
if self.get_db_prep_lookup_value_is_iterable else
|
||||
[get_db_prep_value(value, connection, prepared=True)]
|
||||
)
|
||||
|
||||
|
||||
@@ -259,19 +225,18 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
||||
Some lookups require Field.get_db_prep_value() to be called on each value
|
||||
in an iterable.
|
||||
"""
|
||||
|
||||
get_db_prep_lookup_value_is_iterable = True
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if hasattr(self.rhs, "resolve_expression"):
|
||||
if hasattr(self.rhs, 'resolve_expression'):
|
||||
return self.rhs
|
||||
prepared_values = []
|
||||
for rhs_value in self.rhs:
|
||||
if hasattr(rhs_value, "resolve_expression"):
|
||||
if hasattr(rhs_value, 'resolve_expression'):
|
||||
# An expression will be handled by the database but can coexist
|
||||
# alongside real values.
|
||||
pass
|
||||
elif self.prepare_rhs and hasattr(self.lhs.output_field, "get_prep_value"):
|
||||
elif self.prepare_rhs and hasattr(self.lhs.output_field, 'get_prep_value'):
|
||||
rhs_value = self.lhs.output_field.get_prep_value(rhs_value)
|
||||
prepared_values.append(rhs_value)
|
||||
return prepared_values
|
||||
@@ -286,9 +251,9 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
||||
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
params = [param]
|
||||
if hasattr(param, "resolve_expression"):
|
||||
if hasattr(param, 'resolve_expression'):
|
||||
param = param.resolve_expression(compiler.query)
|
||||
if hasattr(param, "as_sql"):
|
||||
if hasattr(param, 'as_sql'):
|
||||
sql, params = compiler.compile(param)
|
||||
return sql, params
|
||||
|
||||
@@ -298,44 +263,40 @@ class FieldGetDbPrepValueIterableMixin(FieldGetDbPrepValueMixin):
|
||||
# sql/param pair. Zip them to get sql and param pairs that refer to the
|
||||
# same argument and attempt to replace them with the result of
|
||||
# compiling the param step.
|
||||
sql, params = zip(
|
||||
*(
|
||||
self.resolve_expression_parameter(compiler, connection, sql, param)
|
||||
for sql, param in zip(*pre_processed)
|
||||
)
|
||||
)
|
||||
sql, params = zip(*(
|
||||
self.resolve_expression_parameter(compiler, connection, sql, param)
|
||||
for sql, param in zip(*pre_processed)
|
||||
))
|
||||
params = itertools.chain.from_iterable(params)
|
||||
return sql, tuple(params)
|
||||
|
||||
|
||||
class PostgresOperatorLookup(FieldGetDbPrepValueMixin, Lookup):
|
||||
"""Lookup defined by operators on PostgreSQL."""
|
||||
|
||||
postgres_operator = None
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "%s %s %s" % (lhs, self.postgres_operator, rhs), params
|
||||
return '%s %s %s' % (lhs, self.postgres_operator, rhs), params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "exact"
|
||||
lookup_name = 'exact'
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
if isinstance(self.rhs, Query):
|
||||
if self.rhs.has_limit_one():
|
||||
if not self.rhs.has_select_fields:
|
||||
self.rhs.clear_select_clause()
|
||||
self.rhs.add_fields(["pk"])
|
||||
self.rhs.add_fields(['pk'])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The QuerySet value for an exact lookup must be limited to "
|
||||
"one result using slicing."
|
||||
'The QuerySet value for an exact lookup must be limited to '
|
||||
'one result using slicing.'
|
||||
)
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
@@ -344,21 +305,19 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
# turns "boolfield__exact=True" into "WHERE boolean_field" instead of
|
||||
# "WHERE boolean_field = True" when allowed.
|
||||
if (
|
||||
isinstance(self.rhs, bool)
|
||||
and getattr(self.lhs, "conditional", False)
|
||||
and connection.ops.conditional_expression_supported_in_where_clause(
|
||||
self.lhs
|
||||
)
|
||||
isinstance(self.rhs, bool) and
|
||||
getattr(self.lhs, 'conditional', False) and
|
||||
connection.ops.conditional_expression_supported_in_where_clause(self.lhs)
|
||||
):
|
||||
lhs_sql, params = self.process_lhs(compiler, connection)
|
||||
template = "%s" if self.rhs else "NOT %s"
|
||||
template = '%s' if self.rhs else 'NOT %s'
|
||||
return template % lhs_sql, params
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IExact(BuiltinLookup):
|
||||
lookup_name = "iexact"
|
||||
lookup_name = 'iexact'
|
||||
prepare_rhs = False
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
@@ -370,22 +329,22 @@ class IExact(BuiltinLookup):
|
||||
|
||||
@Field.register_lookup
|
||||
class GreaterThan(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "gt"
|
||||
lookup_name = 'gt'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class GreaterThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "gte"
|
||||
lookup_name = 'gte'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class LessThan(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "lt"
|
||||
lookup_name = 'lt'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class LessThanOrEqual(FieldGetDbPrepValueMixin, BuiltinLookup):
|
||||
lookup_name = "lte"
|
||||
lookup_name = 'lte'
|
||||
|
||||
|
||||
class IntegerFieldFloatRounding:
|
||||
@@ -393,7 +352,6 @@ class IntegerFieldFloatRounding:
|
||||
Allow floats to work as query values for IntegerField. Without this, the
|
||||
decimal portion of the float would always be discarded.
|
||||
"""
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if isinstance(self.rhs, float):
|
||||
self.rhs = math.ceil(self.rhs)
|
||||
@@ -412,10 +370,10 @@ class IntegerLessThan(IntegerFieldFloatRounding, LessThan):
|
||||
|
||||
@Field.register_lookup
|
||||
class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = "in"
|
||||
lookup_name = 'in'
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
db_rhs = getattr(self.rhs, "_db", None)
|
||||
db_rhs = getattr(self.rhs, '_db', None)
|
||||
if db_rhs is not None and db_rhs != connection.alias:
|
||||
raise ValueError(
|
||||
"Subqueries aren't allowed across different databases. Force "
|
||||
@@ -436,39 +394,20 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
# rhs should be an iterable; use batch_process_rhs() to
|
||||
# prepare/transform those values.
|
||||
sqls, sqls_params = self.batch_process_rhs(compiler, connection, rhs)
|
||||
placeholder = "(" + ", ".join(sqls) + ")"
|
||||
placeholder = '(' + ', '.join(sqls) + ')'
|
||||
return (placeholder, sqls_params)
|
||||
else:
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(self.rhs, Query):
|
||||
query = self.rhs
|
||||
query.clear_ordering(clear_default=True)
|
||||
if not query.has_select_fields:
|
||||
query.clear_select_clause()
|
||||
query.add_fields(["pk"])
|
||||
|
||||
if not getattr(self.rhs, 'has_select_fields', True):
|
||||
self.rhs.clear_select_clause()
|
||||
self.rhs.add_fields(['pk'])
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
cols = self.lhs.get_group_by_cols()
|
||||
if hasattr(self.rhs, "get_group_by_cols"):
|
||||
if not getattr(self.rhs, "has_select_fields", True):
|
||||
self.rhs.clear_select_clause()
|
||||
self.rhs.add_fields(["pk"])
|
||||
cols.extend(self.rhs.get_group_by_cols())
|
||||
return cols
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "IN %s" % rhs
|
||||
return 'IN %s' % rhs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
if (
|
||||
self.rhs_is_direct_value()
|
||||
and max_in_list_size
|
||||
and len(self.rhs) > max_in_list_size
|
||||
):
|
||||
if self.rhs_is_direct_value() and max_in_list_size and len(self.rhs) > max_in_list_size:
|
||||
return self.split_parameter_list_as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
@@ -478,25 +417,25 @@ class In(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
max_in_list_size = connection.ops.max_in_list_size()
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.batch_process_rhs(compiler, connection)
|
||||
in_clause_elements = ["("]
|
||||
in_clause_elements = ['(']
|
||||
params = []
|
||||
for offset in range(0, len(rhs_params), max_in_list_size):
|
||||
if offset > 0:
|
||||
in_clause_elements.append(" OR ")
|
||||
in_clause_elements.append("%s IN (" % lhs)
|
||||
in_clause_elements.append(' OR ')
|
||||
in_clause_elements.append('%s IN (' % lhs)
|
||||
params.extend(lhs_params)
|
||||
sqls = rhs[offset : offset + max_in_list_size]
|
||||
sqls_params = rhs_params[offset : offset + max_in_list_size]
|
||||
param_group = ", ".join(sqls)
|
||||
sqls = rhs[offset: offset + max_in_list_size]
|
||||
sqls_params = rhs_params[offset: offset + max_in_list_size]
|
||||
param_group = ', '.join(sqls)
|
||||
in_clause_elements.append(param_group)
|
||||
in_clause_elements.append(")")
|
||||
in_clause_elements.append(')')
|
||||
params.extend(sqls_params)
|
||||
in_clause_elements.append(")")
|
||||
return "".join(in_clause_elements), params
|
||||
in_clause_elements.append(')')
|
||||
return ''.join(in_clause_elements), params
|
||||
|
||||
|
||||
class PatternLookup(BuiltinLookup):
|
||||
param_pattern = "%%%s%%"
|
||||
param_pattern = '%%%s%%'
|
||||
prepare_rhs = False
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
@@ -509,10 +448,8 @@ class PatternLookup(BuiltinLookup):
|
||||
# So, for Python values we don't need any special pattern, but for
|
||||
# SQL reference values or SQL transformations we need the correct
|
||||
# pattern added.
|
||||
if hasattr(self.rhs, "as_sql") or self.bilateral_transforms:
|
||||
pattern = connection.pattern_ops[self.lookup_name].format(
|
||||
connection.pattern_esc
|
||||
)
|
||||
if hasattr(self.rhs, 'as_sql') or self.bilateral_transforms:
|
||||
pattern = connection.pattern_ops[self.lookup_name].format(connection.pattern_esc)
|
||||
return pattern.format(rhs)
|
||||
else:
|
||||
return super().get_rhs_op(connection, rhs)
|
||||
@@ -520,47 +457,45 @@ class PatternLookup(BuiltinLookup):
|
||||
def process_rhs(self, qn, connection):
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
|
||||
params[0] = self.param_pattern % connection.ops.prep_for_like_query(
|
||||
params[0]
|
||||
)
|
||||
params[0] = self.param_pattern % connection.ops.prep_for_like_query(params[0])
|
||||
return rhs, params
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Contains(PatternLookup):
|
||||
lookup_name = "contains"
|
||||
lookup_name = 'contains'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IContains(Contains):
|
||||
lookup_name = "icontains"
|
||||
lookup_name = 'icontains'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class StartsWith(PatternLookup):
|
||||
lookup_name = "startswith"
|
||||
param_pattern = "%s%%"
|
||||
lookup_name = 'startswith'
|
||||
param_pattern = '%s%%'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IStartsWith(StartsWith):
|
||||
lookup_name = "istartswith"
|
||||
lookup_name = 'istartswith'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class EndsWith(PatternLookup):
|
||||
lookup_name = "endswith"
|
||||
param_pattern = "%%%s"
|
||||
lookup_name = 'endswith'
|
||||
param_pattern = '%%%s'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class IEndsWith(EndsWith):
|
||||
lookup_name = "iendswith"
|
||||
lookup_name = 'iendswith'
|
||||
|
||||
|
||||
@Field.register_lookup
|
||||
class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
lookup_name = "range"
|
||||
lookup_name = 'range'
|
||||
|
||||
def get_rhs_op(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s" % (rhs[0], rhs[1])
|
||||
@@ -568,13 +503,20 @@ class Range(FieldGetDbPrepValueIterableMixin, BuiltinLookup):
|
||||
|
||||
@Field.register_lookup
|
||||
class IsNull(BuiltinLookup):
|
||||
lookup_name = "isnull"
|
||||
lookup_name = 'isnull'
|
||||
prepare_rhs = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not isinstance(self.rhs, bool):
|
||||
raise ValueError(
|
||||
"The QuerySet value for an isnull lookup must be True or False."
|
||||
# When the deprecation ends, replace with:
|
||||
# raise ValueError(
|
||||
# 'The QuerySet value for an isnull lookup must be True or '
|
||||
# 'False.'
|
||||
# )
|
||||
warnings.warn(
|
||||
'Using a non-boolean value for an isnull lookup is '
|
||||
'deprecated, use True or False instead.',
|
||||
RemovedInDjango40Warning,
|
||||
)
|
||||
sql, params = compiler.compile(self.lhs)
|
||||
if self.rhs:
|
||||
@@ -585,7 +527,7 @@ class IsNull(BuiltinLookup):
|
||||
|
||||
@Field.register_lookup
|
||||
class Regex(BuiltinLookup):
|
||||
lookup_name = "regex"
|
||||
lookup_name = 'regex'
|
||||
prepare_rhs = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
@@ -600,25 +542,16 @@ class Regex(BuiltinLookup):
|
||||
|
||||
@Field.register_lookup
|
||||
class IRegex(Regex):
|
||||
lookup_name = "iregex"
|
||||
lookup_name = 'iregex'
|
||||
|
||||
|
||||
class YearLookup(Lookup):
|
||||
def year_lookup_bounds(self, connection, year):
|
||||
from django.db.models.functions import ExtractIsoYear
|
||||
|
||||
iso_year = isinstance(self.lhs, ExtractIsoYear)
|
||||
output_field = self.lhs.lhs.output_field
|
||||
if isinstance(output_field, DateTimeField):
|
||||
bounds = connection.ops.year_lookup_bounds_for_datetime_field(
|
||||
year,
|
||||
iso_year=iso_year,
|
||||
)
|
||||
bounds = connection.ops.year_lookup_bounds_for_datetime_field(year)
|
||||
else:
|
||||
bounds = connection.ops.year_lookup_bounds_for_date_field(
|
||||
year,
|
||||
iso_year=iso_year,
|
||||
)
|
||||
bounds = connection.ops.year_lookup_bounds_for_date_field(year)
|
||||
return bounds
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
@@ -632,7 +565,7 @@ class YearLookup(Lookup):
|
||||
rhs_sql = self.get_direct_rhs_sql(connection, rhs_sql)
|
||||
start, finish = self.year_lookup_bounds(connection, self.rhs)
|
||||
params.extend(self.get_bound_params(start, finish))
|
||||
return "%s %s" % (lhs_sql, rhs_sql), params
|
||||
return '%s %s' % (lhs_sql, rhs_sql), params
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
def get_direct_rhs_sql(self, connection, rhs):
|
||||
@@ -640,13 +573,13 @@ class YearLookup(Lookup):
|
||||
|
||||
def get_bound_params(self, start, finish):
|
||||
raise NotImplementedError(
|
||||
"subclasses of YearLookup must provide a get_bound_params() method"
|
||||
'subclasses of YearLookup must provide a get_bound_params() method'
|
||||
)
|
||||
|
||||
|
||||
class YearExact(YearLookup, Exact):
|
||||
def get_direct_rhs_sql(self, connection, rhs):
|
||||
return "BETWEEN %s AND %s"
|
||||
return 'BETWEEN %s AND %s'
|
||||
|
||||
def get_bound_params(self, start, finish):
|
||||
return (start, finish)
|
||||
@@ -677,16 +610,12 @@ class UUIDTextMixin:
|
||||
Strip hyphens from a value when filtering a UUIDField on backends without
|
||||
a native datatype for UUID.
|
||||
"""
|
||||
|
||||
def process_rhs(self, qn, connection):
|
||||
if not connection.features.has_native_uuid_field:
|
||||
from django.db.models.functions import Replace
|
||||
|
||||
if self.rhs_is_direct_value():
|
||||
self.rhs = Value(self.rhs)
|
||||
self.rhs = Replace(
|
||||
self.rhs, Value("-"), Value(""), output_field=CharField()
|
||||
)
|
||||
self.rhs = Replace(self.rhs, Value('-'), Value(''), output_field=CharField())
|
||||
rhs, params = super().process_rhs(qn, connection)
|
||||
return rhs, params
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class BaseManager:
|
||||
|
||||
def __str__(self):
|
||||
"""Return "app_label.model_label.manager_name"."""
|
||||
return "%s.%s" % (self.model._meta.label, self.name)
|
||||
return '%s.%s' % (self.model._meta.label, self.name)
|
||||
|
||||
def __class_getitem__(cls, *args, **kwargs):
|
||||
return cls
|
||||
@@ -46,12 +46,12 @@ class BaseManager:
|
||||
Raise a ValueError if the manager is dynamically generated.
|
||||
"""
|
||||
qs_class = self._queryset_class
|
||||
if getattr(self, "_built_with_as_manager", False):
|
||||
if getattr(self, '_built_with_as_manager', False):
|
||||
# using MyQuerySet.as_manager()
|
||||
return (
|
||||
True, # as_manager
|
||||
None, # manager_class
|
||||
"%s.%s" % (qs_class.__module__, qs_class.__name__), # qs_class
|
||||
'%s.%s' % (qs_class.__module__, qs_class.__name__), # qs_class
|
||||
None, # args
|
||||
None, # kwargs
|
||||
)
|
||||
@@ -69,7 +69,7 @@ class BaseManager:
|
||||
)
|
||||
return (
|
||||
False, # as_manager
|
||||
"%s.%s" % (module_name, name), # manager_class
|
||||
'%s.%s' % (module_name, name), # manager_class
|
||||
None, # qs_class
|
||||
self._constructor_args[0], # args
|
||||
self._constructor_args[1], # kwargs
|
||||
@@ -83,22 +83,18 @@ class BaseManager:
|
||||
def create_method(name, method):
|
||||
def manager_method(self, *args, **kwargs):
|
||||
return getattr(self.get_queryset(), name)(*args, **kwargs)
|
||||
|
||||
manager_method.__name__ = method.__name__
|
||||
manager_method.__doc__ = method.__doc__
|
||||
return manager_method
|
||||
|
||||
new_methods = {}
|
||||
for name, method in inspect.getmembers(
|
||||
queryset_class, predicate=inspect.isfunction
|
||||
):
|
||||
for name, method in inspect.getmembers(queryset_class, predicate=inspect.isfunction):
|
||||
# Only copy missing methods.
|
||||
if hasattr(cls, name):
|
||||
continue
|
||||
# Only copy public methods or methods with the attribute
|
||||
# queryset_only=False.
|
||||
queryset_only = getattr(method, "queryset_only", None)
|
||||
if queryset_only or (queryset_only is None and name.startswith("_")):
|
||||
# Only copy public methods or methods with the attribute `queryset_only=False`.
|
||||
queryset_only = getattr(method, 'queryset_only', None)
|
||||
if queryset_only or (queryset_only is None and name.startswith('_')):
|
||||
continue
|
||||
# Copy the method onto the manager.
|
||||
new_methods[name] = create_method(name, method)
|
||||
@@ -107,15 +103,11 @@ class BaseManager:
|
||||
@classmethod
|
||||
def from_queryset(cls, queryset_class, class_name=None):
|
||||
if class_name is None:
|
||||
class_name = "%sFrom%s" % (cls.__name__, queryset_class.__name__)
|
||||
return type(
|
||||
class_name,
|
||||
(cls,),
|
||||
{
|
||||
"_queryset_class": queryset_class,
|
||||
**cls._get_queryset_methods(queryset_class),
|
||||
},
|
||||
)
|
||||
class_name = '%sFrom%s' % (cls.__name__, queryset_class.__name__)
|
||||
return type(class_name, (cls,), {
|
||||
'_queryset_class': queryset_class,
|
||||
**cls._get_queryset_methods(queryset_class),
|
||||
})
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
self.name = self.name or name
|
||||
@@ -165,8 +157,8 @@ class BaseManager:
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, self.__class__)
|
||||
and self._constructor_args == other._constructor_args
|
||||
isinstance(other, self.__class__) and
|
||||
self._constructor_args == other._constructor_args
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
@@ -178,24 +170,22 @@ class Manager(BaseManager.from_queryset(QuerySet)):
|
||||
|
||||
|
||||
class ManagerDescriptor:
|
||||
|
||||
def __init__(self, manager):
|
||||
self.manager = manager
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is not None:
|
||||
raise AttributeError(
|
||||
"Manager isn't accessible via %s instances" % cls.__name__
|
||||
)
|
||||
raise AttributeError("Manager isn't accessible via %s instances" % cls.__name__)
|
||||
|
||||
if cls._meta.abstract:
|
||||
raise AttributeError(
|
||||
"Manager isn't available; %s is abstract" % (cls._meta.object_name,)
|
||||
)
|
||||
raise AttributeError("Manager isn't available; %s is abstract" % (
|
||||
cls._meta.object_name,
|
||||
))
|
||||
|
||||
if cls._meta.swapped:
|
||||
raise AttributeError(
|
||||
"Manager isn't available; '%s' has been swapped for '%s'"
|
||||
% (
|
||||
"Manager isn't available; '%s' has been swapped for '%s'" % (
|
||||
cls._meta.label,
|
||||
cls._meta.swapped,
|
||||
)
|
||||
|
||||
@@ -20,37 +20,18 @@ PROXY_PARENTS = object()
|
||||
EMPTY_RELATION_TREE = ()
|
||||
|
||||
IMMUTABLE_WARNING = (
|
||||
"The return type of '%s' should never be mutated. If you want to manipulate this "
|
||||
"list for your own use, make a copy first."
|
||||
"The return type of '%s' should never be mutated. If you want to manipulate this list "
|
||||
"for your own use, make a copy first."
|
||||
)
|
||||
|
||||
DEFAULT_NAMES = (
|
||||
"verbose_name",
|
||||
"verbose_name_plural",
|
||||
"db_table",
|
||||
"ordering",
|
||||
"unique_together",
|
||||
"permissions",
|
||||
"get_latest_by",
|
||||
"order_with_respect_to",
|
||||
"app_label",
|
||||
"db_tablespace",
|
||||
"abstract",
|
||||
"managed",
|
||||
"proxy",
|
||||
"swappable",
|
||||
"auto_created",
|
||||
"index_together",
|
||||
"apps",
|
||||
"default_permissions",
|
||||
"select_on_save",
|
||||
"default_related_name",
|
||||
"required_db_features",
|
||||
"required_db_vendor",
|
||||
"base_manager_name",
|
||||
"default_manager_name",
|
||||
"indexes",
|
||||
"constraints",
|
||||
'verbose_name', 'verbose_name_plural', 'db_table', 'ordering',
|
||||
'unique_together', 'permissions', 'get_latest_by', 'order_with_respect_to',
|
||||
'app_label', 'db_tablespace', 'abstract', 'managed', 'proxy', 'swappable',
|
||||
'auto_created', 'index_together', 'apps', 'default_permissions',
|
||||
'select_on_save', 'default_related_name', 'required_db_features',
|
||||
'required_db_vendor', 'base_manager_name', 'default_manager_name',
|
||||
'indexes', 'constraints',
|
||||
)
|
||||
|
||||
|
||||
@@ -82,17 +63,11 @@ def make_immutable_fields_list(name, data):
|
||||
|
||||
class Options:
|
||||
FORWARD_PROPERTIES = {
|
||||
"fields",
|
||||
"many_to_many",
|
||||
"concrete_fields",
|
||||
"local_concrete_fields",
|
||||
"_forward_fields_map",
|
||||
"managers",
|
||||
"managers_map",
|
||||
"base_manager",
|
||||
"default_manager",
|
||||
'fields', 'many_to_many', 'concrete_fields', 'local_concrete_fields',
|
||||
'_forward_fields_map', 'managers', 'managers_map', 'base_manager',
|
||||
'default_manager',
|
||||
}
|
||||
REVERSE_PROPERTIES = {"related_objects", "fields_map", "_relation_tree"}
|
||||
REVERSE_PROPERTIES = {'related_objects', 'fields_map', '_relation_tree'}
|
||||
|
||||
default_apps = apps
|
||||
|
||||
@@ -107,7 +82,7 @@ class Options:
|
||||
self.model_name = None
|
||||
self.verbose_name = None
|
||||
self.verbose_name_plural = None
|
||||
self.db_table = ""
|
||||
self.db_table = ''
|
||||
self.ordering = []
|
||||
self._ordering_clash = False
|
||||
self.indexes = []
|
||||
@@ -115,7 +90,7 @@ class Options:
|
||||
self.unique_together = []
|
||||
self.index_together = []
|
||||
self.select_on_save = False
|
||||
self.default_permissions = ("add", "change", "delete", "view")
|
||||
self.default_permissions = ('add', 'change', 'delete', 'view')
|
||||
self.permissions = []
|
||||
self.object_name = None
|
||||
self.app_label = app_label
|
||||
@@ -155,11 +130,11 @@ class Options:
|
||||
|
||||
@property
|
||||
def label(self):
|
||||
return "%s.%s" % (self.app_label, self.object_name)
|
||||
return '%s.%s' % (self.app_label, self.object_name)
|
||||
|
||||
@property
|
||||
def label_lower(self):
|
||||
return "%s.%s" % (self.app_label, self.model_name)
|
||||
return '%s.%s' % (self.app_label, self.model_name)
|
||||
|
||||
@property
|
||||
def app_config(self):
|
||||
@@ -192,7 +167,7 @@ class Options:
|
||||
# Ignore any private attributes that Django doesn't care about.
|
||||
# NOTE: We can't modify a dictionary's contents while looping
|
||||
# over it, so we loop over the *original* dictionary instead.
|
||||
if name.startswith("_"):
|
||||
if name.startswith('_'):
|
||||
del meta_attrs[name]
|
||||
for attr_name in DEFAULT_NAMES:
|
||||
if attr_name in meta_attrs:
|
||||
@@ -206,34 +181,30 @@ class Options:
|
||||
self.index_together = normalize_together(self.index_together)
|
||||
# App label/class name interpolation for names of constraints and
|
||||
# indexes.
|
||||
if not getattr(cls._meta, "abstract", False):
|
||||
for attr_name in {"constraints", "indexes"}:
|
||||
if not getattr(cls._meta, 'abstract', False):
|
||||
for attr_name in {'constraints', 'indexes'}:
|
||||
objs = getattr(self, attr_name, [])
|
||||
setattr(self, attr_name, self._format_names_with_class(cls, objs))
|
||||
|
||||
# verbose_name_plural is a special case because it uses a 's'
|
||||
# by default.
|
||||
if self.verbose_name_plural is None:
|
||||
self.verbose_name_plural = format_lazy("{}s", self.verbose_name)
|
||||
self.verbose_name_plural = format_lazy('{}s', self.verbose_name)
|
||||
|
||||
# order_with_respect_and ordering are mutually exclusive.
|
||||
self._ordering_clash = bool(self.ordering and self.order_with_respect_to)
|
||||
|
||||
# Any leftover attributes must be invalid.
|
||||
if meta_attrs != {}:
|
||||
raise TypeError(
|
||||
"'class Meta' got invalid attribute(s): %s" % ",".join(meta_attrs)
|
||||
)
|
||||
raise TypeError("'class Meta' got invalid attribute(s): %s" % ','.join(meta_attrs))
|
||||
else:
|
||||
self.verbose_name_plural = format_lazy("{}s", self.verbose_name)
|
||||
self.verbose_name_plural = format_lazy('{}s', self.verbose_name)
|
||||
del self.meta
|
||||
|
||||
# If the db_table wasn't provided, use the app_label + model_name.
|
||||
if not self.db_table:
|
||||
self.db_table = "%s_%s" % (self.app_label, self.model_name)
|
||||
self.db_table = truncate_name(
|
||||
self.db_table, connection.ops.max_name_length()
|
||||
)
|
||||
self.db_table = truncate_name(self.db_table, connection.ops.max_name_length())
|
||||
|
||||
def _format_names_with_class(self, cls, objs):
|
||||
"""App label/class name interpolation for object names."""
|
||||
@@ -241,8 +212,8 @@ class Options:
|
||||
for obj in objs:
|
||||
obj = obj.clone()
|
||||
obj.name = obj.name % {
|
||||
"app_label": cls._meta.app_label.lower(),
|
||||
"class": cls.__name__.lower(),
|
||||
'app_label': cls._meta.app_label.lower(),
|
||||
'class': cls.__name__.lower(),
|
||||
}
|
||||
new_objs.append(obj)
|
||||
return new_objs
|
||||
@@ -250,19 +221,19 @@ class Options:
|
||||
def _get_default_pk_class(self):
|
||||
pk_class_path = getattr(
|
||||
self.app_config,
|
||||
"default_auto_field",
|
||||
'default_auto_field',
|
||||
settings.DEFAULT_AUTO_FIELD,
|
||||
)
|
||||
if self.app_config and self.app_config._is_default_auto_field_overridden:
|
||||
app_config_class = type(self.app_config)
|
||||
source = (
|
||||
f"{app_config_class.__module__}."
|
||||
f"{app_config_class.__qualname__}.default_auto_field"
|
||||
f'{app_config_class.__module__}.'
|
||||
f'{app_config_class.__qualname__}.default_auto_field'
|
||||
)
|
||||
else:
|
||||
source = "DEFAULT_AUTO_FIELD"
|
||||
source = 'DEFAULT_AUTO_FIELD'
|
||||
if not pk_class_path:
|
||||
raise ImproperlyConfigured(f"{source} must not be empty.")
|
||||
raise ImproperlyConfigured(f'{source} must not be empty.')
|
||||
try:
|
||||
pk_class = import_string(pk_class_path)
|
||||
except ImportError as e:
|
||||
@@ -285,20 +256,15 @@ class Options:
|
||||
query = self.order_with_respect_to
|
||||
try:
|
||||
self.order_with_respect_to = next(
|
||||
f
|
||||
for f in self._get_fields(reverse=False)
|
||||
f for f in self._get_fields(reverse=False)
|
||||
if f.name == query or f.attname == query
|
||||
)
|
||||
except StopIteration:
|
||||
raise FieldDoesNotExist(
|
||||
"%s has no field named '%s'" % (self.object_name, query)
|
||||
)
|
||||
raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, query))
|
||||
|
||||
self.ordering = ("_order",)
|
||||
if not any(
|
||||
isinstance(field, OrderWrt) for field in model._meta.local_fields
|
||||
):
|
||||
model.add_to_class("_order", OrderWrt())
|
||||
self.ordering = ('_order',)
|
||||
if not any(isinstance(field, OrderWrt) for field in model._meta.local_fields):
|
||||
model.add_to_class('_order', OrderWrt())
|
||||
else:
|
||||
self.order_with_respect_to = None
|
||||
|
||||
@@ -310,17 +276,15 @@ class Options:
|
||||
# Look for a local field with the same name as the
|
||||
# first parent link. If a local field has already been
|
||||
# created, use it instead of promoting the parent
|
||||
already_created = [
|
||||
fld for fld in self.local_fields if fld.name == field.name
|
||||
]
|
||||
already_created = [fld for fld in self.local_fields if fld.name == field.name]
|
||||
if already_created:
|
||||
field = already_created[0]
|
||||
field.primary_key = True
|
||||
self.setup_pk(field)
|
||||
else:
|
||||
pk_class = self._get_default_pk_class()
|
||||
auto = pk_class(verbose_name="ID", primary_key=True, auto_created=True)
|
||||
model.add_to_class("id", auto)
|
||||
auto = pk_class(verbose_name='ID', primary_key=True, auto_created=True)
|
||||
model.add_to_class('id', auto)
|
||||
|
||||
def add_manager(self, manager):
|
||||
self.local_managers.append(manager)
|
||||
@@ -347,11 +311,7 @@ class Options:
|
||||
# ideally, we'd just ask for field.related_model. However, related_model
|
||||
# is a cached property, and all the models haven't been loaded yet, so
|
||||
# we need to make sure we don't cache a string reference.
|
||||
if (
|
||||
field.is_relation
|
||||
and hasattr(field.remote_field, "model")
|
||||
and field.remote_field.model
|
||||
):
|
||||
if field.is_relation and hasattr(field.remote_field, 'model') and field.remote_field.model:
|
||||
try:
|
||||
field.remote_field.model._meta._expire_cache(forward=False)
|
||||
except AttributeError:
|
||||
@@ -375,7 +335,7 @@ class Options:
|
||||
self.db_table = target._meta.db_table
|
||||
|
||||
def __repr__(self):
|
||||
return "<Options for %s>" % self.object_name
|
||||
return '<Options for %s>' % self.object_name
|
||||
|
||||
def __str__(self):
|
||||
return self.label_lower
|
||||
@@ -392,10 +352,8 @@ class Options:
|
||||
if self.required_db_vendor:
|
||||
return self.required_db_vendor == connection.vendor
|
||||
if self.required_db_features:
|
||||
return all(
|
||||
getattr(connection.features, feat, False)
|
||||
for feat in self.required_db_features
|
||||
)
|
||||
return all(getattr(connection.features, feat, False)
|
||||
for feat in self.required_db_features)
|
||||
return True
|
||||
|
||||
@property
|
||||
@@ -417,7 +375,7 @@ class Options:
|
||||
swapped_for = getattr(settings, self.swappable, None)
|
||||
if swapped_for:
|
||||
try:
|
||||
swapped_label, swapped_object = swapped_for.split(".")
|
||||
swapped_label, swapped_object = swapped_for.split('.')
|
||||
except ValueError:
|
||||
# setting not in the format app_label.model_name
|
||||
# raising ImproperlyConfigured here causes problems with
|
||||
@@ -425,10 +383,7 @@ class Options:
|
||||
# or as part of validation.
|
||||
return swapped_for
|
||||
|
||||
if (
|
||||
"%s.%s" % (swapped_label, swapped_object.lower())
|
||||
!= self.label_lower
|
||||
):
|
||||
if '%s.%s' % (swapped_label, swapped_object.lower()) != self.label_lower:
|
||||
return swapped_for
|
||||
return None
|
||||
|
||||
@@ -436,7 +391,7 @@ class Options:
|
||||
def managers(self):
|
||||
managers = []
|
||||
seen_managers = set()
|
||||
bases = (b for b in self.model.mro() if hasattr(b, "_meta"))
|
||||
bases = (b for b in self.model.mro() if hasattr(b, '_meta'))
|
||||
for depth, base in enumerate(bases):
|
||||
for manager in base._meta.local_managers:
|
||||
if manager.name in seen_managers:
|
||||
@@ -462,8 +417,8 @@ class Options:
|
||||
if not base_manager_name:
|
||||
# Get the first parent's base_manager_name if there's one.
|
||||
for parent in self.model.mro()[1:]:
|
||||
if hasattr(parent, "_meta"):
|
||||
if parent._base_manager.name != "_base_manager":
|
||||
if hasattr(parent, '_meta'):
|
||||
if parent._base_manager.name != '_base_manager':
|
||||
base_manager_name = parent._base_manager.name
|
||||
break
|
||||
|
||||
@@ -472,15 +427,14 @@ class Options:
|
||||
return self.managers_map[base_manager_name]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"%s has no manager named %r"
|
||||
% (
|
||||
"%s has no manager named %r" % (
|
||||
self.object_name,
|
||||
base_manager_name,
|
||||
)
|
||||
)
|
||||
|
||||
manager = Manager()
|
||||
manager.name = "_base_manager"
|
||||
manager.name = '_base_manager'
|
||||
manager.model = self.model
|
||||
manager.auto_created = True
|
||||
return manager
|
||||
@@ -491,7 +445,7 @@ class Options:
|
||||
if not default_manager_name and not self.local_managers:
|
||||
# Get the first parent's default_manager_name if there's one.
|
||||
for parent in self.model.mro()[1:]:
|
||||
if hasattr(parent, "_meta"):
|
||||
if hasattr(parent, '_meta'):
|
||||
default_manager_name = parent._meta.default_manager_name
|
||||
break
|
||||
|
||||
@@ -500,8 +454,7 @@ class Options:
|
||||
return self.managers_map[default_manager_name]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"%s has no manager named %r"
|
||||
% (
|
||||
"%s has no manager named %r" % (
|
||||
self.object_name,
|
||||
default_manager_name,
|
||||
)
|
||||
@@ -535,20 +488,13 @@ class Options:
|
||||
|
||||
def is_not_a_generic_foreign_key(f):
|
||||
return not (
|
||||
f.is_relation
|
||||
and f.many_to_one
|
||||
and not (hasattr(f.remote_field, "model") and f.remote_field.model)
|
||||
f.is_relation and f.many_to_one and not (hasattr(f.remote_field, 'model') and f.remote_field.model)
|
||||
)
|
||||
|
||||
return make_immutable_fields_list(
|
||||
"fields",
|
||||
(
|
||||
f
|
||||
for f in self._get_fields(reverse=False)
|
||||
if is_not_an_m2m_field(f)
|
||||
and is_not_a_generic_relation(f)
|
||||
and is_not_a_generic_foreign_key(f)
|
||||
),
|
||||
(f for f in self._get_fields(reverse=False)
|
||||
if is_not_an_m2m_field(f) and is_not_a_generic_relation(f) and is_not_a_generic_foreign_key(f))
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@@ -588,11 +534,7 @@ class Options:
|
||||
"""
|
||||
return make_immutable_fields_list(
|
||||
"many_to_many",
|
||||
(
|
||||
f
|
||||
for f in self._get_fields(reverse=False)
|
||||
if f.is_relation and f.many_to_many
|
||||
),
|
||||
(f for f in self._get_fields(reverse=False) if f.is_relation and f.many_to_many)
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@@ -606,16 +548,10 @@ class Options:
|
||||
combined with filtering of field properties is the public API for
|
||||
obtaining this field list.
|
||||
"""
|
||||
all_related_fields = self._get_fields(
|
||||
forward=False, reverse=True, include_hidden=True
|
||||
)
|
||||
all_related_fields = self._get_fields(forward=False, reverse=True, include_hidden=True)
|
||||
return make_immutable_fields_list(
|
||||
"related_objects",
|
||||
(
|
||||
obj
|
||||
for obj in all_related_fields
|
||||
if not obj.hidden or obj.field.many_to_many
|
||||
),
|
||||
(obj for obj in all_related_fields if not obj.hidden or obj.field.many_to_many)
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@@ -671,9 +607,7 @@ class Options:
|
||||
# field map.
|
||||
return self.fields_map[field_name]
|
||||
except KeyError:
|
||||
raise FieldDoesNotExist(
|
||||
"%s has no field named '%s'" % (self.object_name, field_name)
|
||||
)
|
||||
raise FieldDoesNotExist("%s has no field named '%s'" % (self.object_name, field_name))
|
||||
|
||||
def get_base_chain(self, model):
|
||||
"""
|
||||
@@ -742,17 +676,15 @@ class Options:
|
||||
final_field = opts.parents[int_model]
|
||||
targets = (final_field.remote_field.get_related_field(),)
|
||||
opts = int_model._meta
|
||||
path.append(
|
||||
PathInfo(
|
||||
from_opts=final_field.model._meta,
|
||||
to_opts=opts,
|
||||
target_fields=targets,
|
||||
join_field=final_field,
|
||||
m2m=False,
|
||||
direct=True,
|
||||
filtered_relation=None,
|
||||
)
|
||||
)
|
||||
path.append(PathInfo(
|
||||
from_opts=final_field.model._meta,
|
||||
to_opts=opts,
|
||||
target_fields=targets,
|
||||
join_field=final_field,
|
||||
m2m=False,
|
||||
direct=True,
|
||||
filtered_relation=None,
|
||||
))
|
||||
return path
|
||||
|
||||
def get_path_from_parent(self, parent):
|
||||
@@ -794,8 +726,7 @@ class Options:
|
||||
if opts.abstract:
|
||||
continue
|
||||
fields_with_relations = (
|
||||
f
|
||||
for f in opts._get_fields(reverse=False, include_parents=False)
|
||||
f for f in opts._get_fields(reverse=False, include_parents=False)
|
||||
if f.is_relation and f.related_model is not None
|
||||
)
|
||||
for f in fields_with_relations:
|
||||
@@ -809,13 +740,11 @@ class Options:
|
||||
# __dict__ takes precedence over a data descriptor (such as
|
||||
# @cached_property). This means that the _meta._relation_tree is
|
||||
# only called if related_objects is not in __dict__.
|
||||
related_objects = related_objects_graph[
|
||||
model._meta.concrete_model._meta.label
|
||||
]
|
||||
model._meta.__dict__["_relation_tree"] = related_objects
|
||||
related_objects = related_objects_graph[model._meta.concrete_model._meta.label]
|
||||
model._meta.__dict__['_relation_tree'] = related_objects
|
||||
# It seems it is possible that self is not in all_models, so guard
|
||||
# against that with default for get().
|
||||
return self.__dict__.get("_relation_tree", EMPTY_RELATION_TREE)
|
||||
return self.__dict__.get('_relation_tree', EMPTY_RELATION_TREE)
|
||||
|
||||
@cached_property
|
||||
def _relation_tree(self):
|
||||
@@ -846,18 +775,10 @@ class Options:
|
||||
"""
|
||||
if include_parents is False:
|
||||
include_parents = PROXY_PARENTS
|
||||
return self._get_fields(
|
||||
include_parents=include_parents, include_hidden=include_hidden
|
||||
)
|
||||
return self._get_fields(include_parents=include_parents, include_hidden=include_hidden)
|
||||
|
||||
def _get_fields(
|
||||
self,
|
||||
forward=True,
|
||||
reverse=True,
|
||||
include_parents=True,
|
||||
include_hidden=False,
|
||||
seen_models=None,
|
||||
):
|
||||
def _get_fields(self, forward=True, reverse=True, include_parents=True, include_hidden=False,
|
||||
seen_models=None):
|
||||
"""
|
||||
Internal helper function to return fields of the model.
|
||||
* If forward=True, then fields defined on this model are returned.
|
||||
@@ -870,9 +791,7 @@ class Options:
|
||||
parent chain to the model's concrete model.
|
||||
"""
|
||||
if include_parents not in (True, False, PROXY_PARENTS):
|
||||
raise TypeError(
|
||||
"Invalid argument for include_parents: %s" % (include_parents,)
|
||||
)
|
||||
raise TypeError("Invalid argument for include_parents: %s" % (include_parents,))
|
||||
# This helper function is used to allow recursion in ``get_fields()``
|
||||
# implementation and to provide a fast way for Django's internals to
|
||||
# access specific subsets of fields.
|
||||
@@ -904,22 +823,13 @@ class Options:
|
||||
# fields from the same parent again.
|
||||
if parent in seen_models:
|
||||
continue
|
||||
if (
|
||||
parent._meta.concrete_model != self.concrete_model
|
||||
and include_parents == PROXY_PARENTS
|
||||
):
|
||||
if (parent._meta.concrete_model != self.concrete_model and
|
||||
include_parents == PROXY_PARENTS):
|
||||
continue
|
||||
for obj in parent._meta._get_fields(
|
||||
forward=forward,
|
||||
reverse=reverse,
|
||||
include_parents=include_parents,
|
||||
include_hidden=include_hidden,
|
||||
seen_models=seen_models,
|
||||
):
|
||||
if (
|
||||
not getattr(obj, "parent_link", False)
|
||||
or obj.model == self.concrete_model
|
||||
):
|
||||
forward=forward, reverse=reverse, include_parents=include_parents,
|
||||
include_hidden=include_hidden, seen_models=seen_models):
|
||||
if not getattr(obj, 'parent_link', False) or obj.model == self.concrete_model:
|
||||
fields.append(obj)
|
||||
if reverse and not self.proxy:
|
||||
# Tree is computed once and cached until the app cache is expired.
|
||||
@@ -960,11 +870,7 @@ class Options:
|
||||
return [
|
||||
constraint
|
||||
for constraint in self.constraints
|
||||
if (
|
||||
isinstance(constraint, UniqueConstraint)
|
||||
and constraint.condition is None
|
||||
and not constraint.contains_expressions
|
||||
)
|
||||
if isinstance(constraint, UniqueConstraint) and constraint.condition is None
|
||||
]
|
||||
|
||||
@cached_property
|
||||
@@ -984,9 +890,6 @@ class Options:
|
||||
Fields to be returned after a database insert.
|
||||
"""
|
||||
return [
|
||||
field
|
||||
for field in self._get_fields(
|
||||
forward=True, reverse=False, include_parents=PROXY_PARENTS
|
||||
)
|
||||
if getattr(field, "db_returning", False)
|
||||
field for field in self._get_fields(forward=True, reverse=False, include_parents=PROXY_PARENTS)
|
||||
if getattr(field, 'db_returning', False)
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,19 +8,44 @@ circular import difficulties.
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.core.exceptions import FieldDoesNotExist, FieldError
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.utils import tree
|
||||
from django.utils.deprecation import RemovedInDjango40Warning
|
||||
|
||||
# PathInfo is used when converting lookups (fk__somecol). The contents
|
||||
# describe the relation in Model terms (model Options and Fields for both
|
||||
# sides of the relation. The join_field is the field backing the relation.
|
||||
PathInfo = namedtuple(
|
||||
"PathInfo",
|
||||
"from_opts to_opts target_fields join_field m2m direct filtered_relation",
|
||||
)
|
||||
PathInfo = namedtuple('PathInfo', 'from_opts to_opts target_fields join_field m2m direct filtered_relation')
|
||||
|
||||
|
||||
class InvalidQueryType(type):
|
||||
@property
|
||||
def _subclasses(self):
|
||||
return (FieldDoesNotExist, FieldError)
|
||||
|
||||
def __warn(self):
|
||||
warnings.warn(
|
||||
'The InvalidQuery exception class is deprecated. Use '
|
||||
'FieldDoesNotExist or FieldError instead.',
|
||||
category=RemovedInDjango40Warning,
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
self.__warn()
|
||||
return isinstance(instance, self._subclasses) or super().__instancecheck__(instance)
|
||||
|
||||
def __subclasscheck__(self, subclass):
|
||||
self.__warn()
|
||||
return issubclass(subclass, self._subclasses) or super().__subclasscheck__(subclass)
|
||||
|
||||
|
||||
class InvalidQuery(Exception, metaclass=InvalidQueryType):
|
||||
pass
|
||||
|
||||
|
||||
def subclasses(cls):
|
||||
@@ -34,26 +59,21 @@ class Q(tree.Node):
|
||||
Encapsulate filters as objects that can then be combined logically (using
|
||||
`&` and `|`).
|
||||
"""
|
||||
|
||||
# Connection types
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
AND = 'AND'
|
||||
OR = 'OR'
|
||||
default = AND
|
||||
conditional = True
|
||||
|
||||
def __init__(self, *args, _connector=None, _negated=False, **kwargs):
|
||||
super().__init__(
|
||||
children=[*args, *sorted(kwargs.items())],
|
||||
connector=_connector,
|
||||
negated=_negated,
|
||||
)
|
||||
super().__init__(children=[*args, *sorted(kwargs.items())], connector=_connector, negated=_negated)
|
||||
|
||||
def _combine(self, other, conn):
|
||||
if not (isinstance(other, Q) or getattr(other, "conditional", False) is True):
|
||||
if not(isinstance(other, Q) or getattr(other, 'conditional', False) is True):
|
||||
raise TypeError(other)
|
||||
|
||||
if not self:
|
||||
return other.copy() if hasattr(other, "copy") else copy.copy(other)
|
||||
return other.copy() if hasattr(other, 'copy') else copy.copy(other)
|
||||
elif isinstance(other, Q) and not other:
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return type(self)(*args, **kwargs)
|
||||
@@ -76,31 +96,26 @@ class Q(tree.Node):
|
||||
obj.negate()
|
||||
return obj
|
||||
|
||||
def resolve_expression(
|
||||
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
|
||||
):
|
||||
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
|
||||
# We must promote any new joins to left outer joins so that when Q is
|
||||
# used as an expression, rows aren't filtered due to joins.
|
||||
clause, joins = query._add_q(
|
||||
self,
|
||||
reuse,
|
||||
allow_joins=allow_joins,
|
||||
split_subq=False,
|
||||
self, reuse, allow_joins=allow_joins, split_subq=False,
|
||||
check_filterable=False,
|
||||
)
|
||||
query.promote_joins(joins)
|
||||
return clause
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
if path.startswith("django.db.models.query_utils"):
|
||||
path = path.replace("django.db.models.query_utils", "django.db.models")
|
||||
path = '%s.%s' % (self.__class__.__module__, self.__class__.__name__)
|
||||
if path.startswith('django.db.models.query_utils'):
|
||||
path = path.replace('django.db.models.query_utils', 'django.db.models')
|
||||
args = tuple(self.children)
|
||||
kwargs = {}
|
||||
if self.connector != self.default:
|
||||
kwargs["_connector"] = self.connector
|
||||
kwargs['_connector'] = self.connector
|
||||
if self.negated:
|
||||
kwargs["_negated"] = True
|
||||
kwargs['_negated'] = True
|
||||
return path, args, kwargs
|
||||
|
||||
|
||||
@@ -109,7 +124,6 @@ class DeferredAttribute:
|
||||
A wrapper for a deferred-loading field. When the value is read from this
|
||||
object the first time, the query is executed.
|
||||
"""
|
||||
|
||||
def __init__(self, field):
|
||||
self.field = field
|
||||
|
||||
@@ -146,6 +160,7 @@ class DeferredAttribute:
|
||||
|
||||
|
||||
class RegisterLookupMixin:
|
||||
|
||||
@classmethod
|
||||
def _get_lookup(cls, lookup_name):
|
||||
return cls.get_lookups().get(lookup_name, None)
|
||||
@@ -153,16 +168,13 @@ class RegisterLookupMixin:
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_lookups(cls):
|
||||
class_lookups = [
|
||||
parent.__dict__.get("class_lookups", {}) for parent in inspect.getmro(cls)
|
||||
]
|
||||
class_lookups = [parent.__dict__.get('class_lookups', {}) for parent in inspect.getmro(cls)]
|
||||
return cls.merge_dicts(class_lookups)
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
from django.db.models.lookups import Lookup
|
||||
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, "output_field"):
|
||||
if found is None and hasattr(self, 'output_field'):
|
||||
return self.output_field.get_lookup(lookup_name)
|
||||
if found is not None and not issubclass(found, Lookup):
|
||||
return None
|
||||
@@ -170,9 +182,8 @@ class RegisterLookupMixin:
|
||||
|
||||
def get_transform(self, lookup_name):
|
||||
from django.db.models.lookups import Transform
|
||||
|
||||
found = self._get_lookup(lookup_name)
|
||||
if found is None and hasattr(self, "output_field"):
|
||||
if found is None and hasattr(self, 'output_field'):
|
||||
return self.output_field.get_transform(lookup_name)
|
||||
if found is not None and not issubclass(found, Transform):
|
||||
return None
|
||||
@@ -198,7 +209,7 @@ class RegisterLookupMixin:
|
||||
def register_lookup(cls, lookup, lookup_name=None):
|
||||
if lookup_name is None:
|
||||
lookup_name = lookup.lookup_name
|
||||
if "class_lookups" not in cls.__dict__:
|
||||
if 'class_lookups' not in cls.__dict__:
|
||||
cls.class_lookups = {}
|
||||
cls.class_lookups[lookup_name] = lookup
|
||||
cls._clear_cached_lookups()
|
||||
@@ -245,8 +256,8 @@ def select_related_descend(field, restricted, requested, load_fields, reverse=Fa
|
||||
if field.attname not in load_fields:
|
||||
if restricted and field.name in requested:
|
||||
msg = (
|
||||
"Field %s.%s cannot be both deferred and traversed using "
|
||||
"select_related at the same time."
|
||||
'Field %s.%s cannot be both deferred and traversed using '
|
||||
'select_related at the same time.'
|
||||
) % (field.model._meta.object_name, field.name)
|
||||
raise FieldError(msg)
|
||||
return True
|
||||
@@ -272,14 +283,12 @@ def check_rel_lookup_compatibility(model, target_opts, field):
|
||||
1) model and opts match (where proxy inheritance is removed)
|
||||
2) model is parent of opts' model or the other way around
|
||||
"""
|
||||
|
||||
def check(opts):
|
||||
return (
|
||||
model._meta.concrete_model == opts.concrete_model
|
||||
or opts.concrete_model in model._meta.get_parent_list()
|
||||
or model in opts.get_parent_list()
|
||||
model._meta.concrete_model == opts.concrete_model or
|
||||
opts.concrete_model in model._meta.get_parent_list() or
|
||||
model in opts.get_parent_list()
|
||||
)
|
||||
|
||||
# If the field is a primary key, then doing a query against the field's
|
||||
# model is ok, too. Consider the case:
|
||||
# class Restaurant(models.Model):
|
||||
@@ -289,8 +298,9 @@ def check_rel_lookup_compatibility(model, target_opts, field):
|
||||
# give Place's opts as the target opts, but Restaurant isn't compatible
|
||||
# with that. This logic applies only to primary keys, as when doing __in=qs,
|
||||
# we are going to turn this into __in=qs.values('pk') later on.
|
||||
return check(target_opts) or (
|
||||
getattr(field, "primary_key", False) and check(field.model._meta)
|
||||
return (
|
||||
check(target_opts) or
|
||||
(getattr(field, 'primary_key', False) and check(field.model._meta))
|
||||
)
|
||||
|
||||
|
||||
@@ -299,11 +309,11 @@ class FilteredRelation:
|
||||
|
||||
def __init__(self, relation_name, *, condition=Q()):
|
||||
if not relation_name:
|
||||
raise ValueError("relation_name cannot be empty.")
|
||||
raise ValueError('relation_name cannot be empty.')
|
||||
self.relation_name = relation_name
|
||||
self.alias = None
|
||||
if not isinstance(condition, Q):
|
||||
raise ValueError("condition argument must be a Q() instance.")
|
||||
raise ValueError('condition argument must be a Q() instance.')
|
||||
self.condition = condition
|
||||
self.path = []
|
||||
|
||||
@@ -311,9 +321,9 @@ class FilteredRelation:
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.relation_name == other.relation_name
|
||||
and self.alias == other.alias
|
||||
and self.condition == other.condition
|
||||
self.relation_name == other.relation_name and
|
||||
self.alias == other.alias and
|
||||
self.condition == other.condition
|
||||
)
|
||||
|
||||
def clone(self):
|
||||
@@ -327,7 +337,7 @@ class FilteredRelation:
|
||||
QuerySet.annotate() only accepts expression-like arguments
|
||||
(with a resolve_expression() method).
|
||||
"""
|
||||
raise NotImplementedError("FilteredRelation.resolve_expression() is unused.")
|
||||
raise NotImplementedError('FilteredRelation.resolve_expression() is unused.')
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# Resolve the condition in Join.filtered_relation.
|
||||
|
||||
@@ -11,7 +11,6 @@ class ModelSignal(Signal):
|
||||
Signal subclass that allows the sender to be lazily specified as a string
|
||||
of the `app_label.ModelName` form.
|
||||
"""
|
||||
|
||||
def _lazy_method(self, method, apps, receiver, sender, **kwargs):
|
||||
from django.db.models.options import Options
|
||||
|
||||
@@ -25,12 +24,8 @@ class ModelSignal(Signal):
|
||||
|
||||
def connect(self, receiver, sender=None, weak=True, dispatch_uid=None, apps=None):
|
||||
self._lazy_method(
|
||||
super().connect,
|
||||
apps,
|
||||
receiver,
|
||||
sender,
|
||||
weak=weak,
|
||||
dispatch_uid=dispatch_uid,
|
||||
super().connect, apps, receiver, sender,
|
||||
weak=weak, dispatch_uid=dispatch_uid,
|
||||
)
|
||||
|
||||
def disconnect(self, receiver=None, sender=None, dispatch_uid=None, apps=None):
|
||||
|
||||
@@ -3,4 +3,4 @@ from django.db.models.sql.query import Query
|
||||
from django.db.models.sql.subqueries import * # NOQA
|
||||
from django.db.models.sql.where import AND, OR
|
||||
|
||||
__all__ = ["Query", "AND", "OR"]
|
||||
__all__ = ['Query', 'AND', 'OR']
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Constants specific to the SQL storage portion of the ORM.
|
||||
"""
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
# Size of each "chunk" for get_iterator calls.
|
||||
# Larger values are slightly faster at the expense of more storage space.
|
||||
@@ -9,16 +10,17 @@ GET_ITERATOR_CHUNK_SIZE = 100
|
||||
# Namedtuples for sql.* internal use.
|
||||
|
||||
# How many results to expect from a cursor.execute call
|
||||
MULTI = "multi"
|
||||
SINGLE = "single"
|
||||
CURSOR = "cursor"
|
||||
NO_RESULTS = "no results"
|
||||
MULTI = 'multi'
|
||||
SINGLE = 'single'
|
||||
CURSOR = 'cursor'
|
||||
NO_RESULTS = 'no results'
|
||||
|
||||
ORDER_DIR = {
|
||||
"ASC": ("ASC", "DESC"),
|
||||
"DESC": ("DESC", "ASC"),
|
||||
'ASC': ('ASC', 'DESC'),
|
||||
'DESC': ('DESC', 'ASC'),
|
||||
}
|
||||
ORDER_PATTERN = _lazy_re_compile(r'[-+]?[.\w]+$')
|
||||
|
||||
# SQL join types.
|
||||
INNER = "INNER JOIN"
|
||||
LOUTER = "LEFT OUTER JOIN"
|
||||
INNER = 'INNER JOIN'
|
||||
LOUTER = 'LEFT OUTER JOIN'
|
||||
|
||||
@@ -11,7 +11,6 @@ class MultiJoin(Exception):
|
||||
multi-valued join was attempted (if the caller wants to treat that
|
||||
exceptionally).
|
||||
"""
|
||||
|
||||
def __init__(self, names_pos, path_with_names):
|
||||
self.level = names_pos
|
||||
# The path travelled, this includes the path to the multijoin.
|
||||
@@ -26,8 +25,7 @@ class Join:
|
||||
"""
|
||||
Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
|
||||
FROM entry. For example, the SQL generated could be
|
||||
LEFT OUTER JOIN "sometable" T1
|
||||
ON ("othertable"."sometable_id" = "sometable"."id")
|
||||
LEFT OUTER JOIN "sometable" T1 ON ("othertable"."sometable_id" = "sometable"."id")
|
||||
|
||||
This class is primarily used in Query.alias_map. All entries in alias_map
|
||||
must be Join compatible by providing the following attributes and methods:
|
||||
@@ -40,17 +38,8 @@ class Join:
|
||||
- as_sql()
|
||||
- relabeled_clone()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table_name,
|
||||
parent_alias,
|
||||
table_alias,
|
||||
join_type,
|
||||
join_field,
|
||||
nullable,
|
||||
filtered_relation=None,
|
||||
):
|
||||
def __init__(self, table_name, parent_alias, table_alias, join_type,
|
||||
join_field, nullable, filtered_relation=None):
|
||||
# Join table
|
||||
self.table_name = table_name
|
||||
self.parent_alias = parent_alias
|
||||
@@ -80,47 +69,36 @@ class Join:
|
||||
|
||||
# Add a join condition for each pair of joining columns.
|
||||
for lhs_col, rhs_col in self.join_cols:
|
||||
join_conditions.append(
|
||||
"%s.%s = %s.%s"
|
||||
% (
|
||||
qn(self.parent_alias),
|
||||
qn2(lhs_col),
|
||||
qn(self.table_alias),
|
||||
qn2(rhs_col),
|
||||
)
|
||||
)
|
||||
join_conditions.append('%s.%s = %s.%s' % (
|
||||
qn(self.parent_alias),
|
||||
qn2(lhs_col),
|
||||
qn(self.table_alias),
|
||||
qn2(rhs_col),
|
||||
))
|
||||
|
||||
# Add a single condition inside parentheses for whatever
|
||||
# get_extra_restriction() returns.
|
||||
extra_cond = self.join_field.get_extra_restriction(
|
||||
self.table_alias, self.parent_alias
|
||||
)
|
||||
compiler.query.where_class, self.table_alias, self.parent_alias)
|
||||
if extra_cond:
|
||||
extra_sql, extra_params = compiler.compile(extra_cond)
|
||||
join_conditions.append("(%s)" % extra_sql)
|
||||
join_conditions.append('(%s)' % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if self.filtered_relation:
|
||||
extra_sql, extra_params = compiler.compile(self.filtered_relation)
|
||||
if extra_sql:
|
||||
join_conditions.append("(%s)" % extra_sql)
|
||||
join_conditions.append('(%s)' % extra_sql)
|
||||
params.extend(extra_params)
|
||||
if not join_conditions:
|
||||
# This might be a rel on the other end of an actual declared field.
|
||||
declared_field = getattr(self.join_field, "field", self.join_field)
|
||||
declared_field = getattr(self.join_field, 'field', self.join_field)
|
||||
raise ValueError(
|
||||
"Join generated an empty ON clause. %s did not yield either "
|
||||
"joining columns or extra restrictions." % declared_field.__class__
|
||||
)
|
||||
on_clause_sql = " AND ".join(join_conditions)
|
||||
alias_str = (
|
||||
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
|
||||
)
|
||||
sql = "%s %s%s ON (%s)" % (
|
||||
self.join_type,
|
||||
qn(self.table_name),
|
||||
alias_str,
|
||||
on_clause_sql,
|
||||
)
|
||||
on_clause_sql = ' AND '.join(join_conditions)
|
||||
alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
|
||||
sql = '%s %s%s ON (%s)' % (self.join_type, qn(self.table_name), alias_str, on_clause_sql)
|
||||
return sql, params
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
@@ -128,19 +106,12 @@ class Join:
|
||||
new_table_alias = change_map.get(self.table_alias, self.table_alias)
|
||||
if self.filtered_relation is not None:
|
||||
filtered_relation = self.filtered_relation.clone()
|
||||
filtered_relation.path = [
|
||||
change_map.get(p, p) for p in self.filtered_relation.path
|
||||
]
|
||||
filtered_relation.path = [change_map.get(p, p) for p in self.filtered_relation.path]
|
||||
else:
|
||||
filtered_relation = None
|
||||
return self.__class__(
|
||||
self.table_name,
|
||||
new_parent_alias,
|
||||
new_table_alias,
|
||||
self.join_type,
|
||||
self.join_field,
|
||||
self.nullable,
|
||||
filtered_relation=filtered_relation,
|
||||
self.table_name, new_parent_alias, new_table_alias, self.join_type,
|
||||
self.join_field, self.nullable, filtered_relation=filtered_relation,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -161,8 +132,9 @@ class Join:
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other):
|
||||
# Ignore filtered_relation in equality check.
|
||||
def equals(self, other, with_filtered_relation):
|
||||
if with_filtered_relation:
|
||||
return self == other
|
||||
return self.identity[:-1] == other.identity[:-1]
|
||||
|
||||
def demote(self):
|
||||
@@ -183,7 +155,6 @@ class BaseTable:
|
||||
SELECT * FROM "foo" WHERE somecond
|
||||
could be generated by this class.
|
||||
"""
|
||||
|
||||
join_type = None
|
||||
parent_alias = None
|
||||
filtered_relation = None
|
||||
@@ -193,16 +164,12 @@ class BaseTable:
|
||||
self.table_alias = alias
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
alias_str = (
|
||||
"" if self.table_alias == self.table_name else (" %s" % self.table_alias)
|
||||
)
|
||||
alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias)
|
||||
base_sql = compiler.quote_name_unless_alias(self.table_name)
|
||||
return base_sql + alias_str, []
|
||||
|
||||
def relabeled_clone(self, change_map):
|
||||
return self.__class__(
|
||||
self.table_name, change_map.get(self.table_alias, self.table_alias)
|
||||
)
|
||||
return self.__class__(self.table_name, change_map.get(self.table_alias, self.table_alias))
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
@@ -216,5 +183,5 @@ class BaseTable:
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def equals(self, other):
|
||||
def equals(self, other, with_filtered_relation):
|
||||
return self.identity == other.identity
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,16 +3,19 @@ Query subclasses which provide extra functionality beyond simple data retrieval.
|
||||
"""
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql.constants import (
|
||||
CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS,
|
||||
)
|
||||
from django.db.models.sql.query import Query
|
||||
|
||||
__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
|
||||
__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'AggregateQuery']
|
||||
|
||||
|
||||
class DeleteQuery(Query):
|
||||
"""A DELETE SQL query."""
|
||||
|
||||
compiler = "SQLDeleteCompiler"
|
||||
compiler = 'SQLDeleteCompiler'
|
||||
|
||||
def do_query(self, table, where, using):
|
||||
self.alias_map = {table: self.alias_map[table]}
|
||||
@@ -34,21 +37,17 @@ class DeleteQuery(Query):
|
||||
num_deleted = 0
|
||||
field = self.get_meta().pk
|
||||
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
|
||||
self.clear_where()
|
||||
self.add_filter(
|
||||
f"{field.attname}__in",
|
||||
pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
|
||||
)
|
||||
num_deleted += self.do_query(
|
||||
self.get_meta().db_table, self.where, using=using
|
||||
)
|
||||
self.where = self.where_class()
|
||||
self.add_q(Q(
|
||||
**{field.attname + '__in': pk_list[offset:offset + GET_ITERATOR_CHUNK_SIZE]}))
|
||||
num_deleted += self.do_query(self.get_meta().db_table, self.where, using=using)
|
||||
return num_deleted
|
||||
|
||||
|
||||
class UpdateQuery(Query):
|
||||
"""An UPDATE SQL query."""
|
||||
|
||||
compiler = "SQLUpdateCompiler"
|
||||
compiler = 'SQLUpdateCompiler'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -71,10 +70,8 @@ class UpdateQuery(Query):
|
||||
def update_batch(self, pk_list, values, using):
|
||||
self.add_update_values(values)
|
||||
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
|
||||
self.clear_where()
|
||||
self.add_filter(
|
||||
"pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
|
||||
)
|
||||
self.where = self.where_class()
|
||||
self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
|
||||
self.get_compiler(using).execute_sql(NO_RESULTS)
|
||||
|
||||
def add_update_values(self, values):
|
||||
@@ -86,14 +83,12 @@ class UpdateQuery(Query):
|
||||
values_seq = []
|
||||
for name, val in values.items():
|
||||
field = self.get_meta().get_field(name)
|
||||
direct = (
|
||||
not (field.auto_created and not field.concrete) or not field.concrete
|
||||
)
|
||||
direct = not (field.auto_created and not field.concrete) or not field.concrete
|
||||
model = field.model._meta.concrete_model
|
||||
if not direct or (field.is_relation and field.many_to_many):
|
||||
raise FieldError(
|
||||
"Cannot update model field %r (only non-relations and "
|
||||
"foreign keys permitted)." % field
|
||||
'Cannot update model field %r (only non-relations and '
|
||||
'foreign keys permitted).' % field
|
||||
)
|
||||
if model is not self.get_meta().concrete_model:
|
||||
self.add_related_update(model, field, val)
|
||||
@@ -108,7 +103,7 @@ class UpdateQuery(Query):
|
||||
called add_update_targets() to hint at the extra information here.
|
||||
"""
|
||||
for field, model, val in values_seq:
|
||||
if hasattr(val, "resolve_expression"):
|
||||
if hasattr(val, 'resolve_expression'):
|
||||
# Resolve expressions here so that annotations are no longer needed
|
||||
val = val.resolve_expression(self, allow_joins=False, for_save=True)
|
||||
self.values.append((field, model, val))
|
||||
@@ -134,13 +129,13 @@ class UpdateQuery(Query):
|
||||
query = UpdateQuery(model)
|
||||
query.values = values
|
||||
if self.related_ids is not None:
|
||||
query.add_filter("pk__in", self.related_ids)
|
||||
query.add_filter(('pk__in', self.related_ids))
|
||||
result.append(query)
|
||||
return result
|
||||
|
||||
|
||||
class InsertQuery(Query):
|
||||
compiler = "SQLInsertCompiler"
|
||||
compiler = 'SQLInsertCompiler'
|
||||
|
||||
def __init__(self, *args, ignore_conflicts=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -160,7 +155,7 @@ class AggregateQuery(Query):
|
||||
elements in the provided list.
|
||||
"""
|
||||
|
||||
compiler = "SQLAggregateCompiler"
|
||||
compiler = 'SQLAggregateCompiler'
|
||||
|
||||
def __init__(self, model, inner_query):
|
||||
self.inner_query = inner_query
|
||||
|
||||
@@ -7,8 +7,8 @@ from django.utils import tree
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
# Connection types
|
||||
AND = "AND"
|
||||
OR = "OR"
|
||||
AND = 'AND'
|
||||
OR = 'OR'
|
||||
|
||||
|
||||
class WhereNode(tree.Node):
|
||||
@@ -25,7 +25,6 @@ class WhereNode(tree.Node):
|
||||
relabeled_clone() method or relabel_aliases() and clone() methods and
|
||||
contains_aggregate attribute.
|
||||
"""
|
||||
|
||||
default = AND
|
||||
resolved = False
|
||||
conditional = True
|
||||
@@ -41,15 +40,15 @@ class WhereNode(tree.Node):
|
||||
in_negated = negated ^ self.negated
|
||||
# If the effective connector is OR and this node contains an aggregate,
|
||||
# then we need to push the whole branch to HAVING clause.
|
||||
may_need_split = (in_negated and self.connector == AND) or (
|
||||
not in_negated and self.connector == OR
|
||||
)
|
||||
may_need_split = (
|
||||
(in_negated and self.connector == AND) or
|
||||
(not in_negated and self.connector == OR))
|
||||
if may_need_split and self.contains_aggregate:
|
||||
return None, self
|
||||
where_parts = []
|
||||
having_parts = []
|
||||
for c in self.children:
|
||||
if hasattr(c, "split_having"):
|
||||
if hasattr(c, 'split_having'):
|
||||
where_part, having_part = c.split_having(in_negated)
|
||||
if where_part is not None:
|
||||
where_parts.append(where_part)
|
||||
@@ -59,16 +58,8 @@ class WhereNode(tree.Node):
|
||||
having_parts.append(c)
|
||||
else:
|
||||
where_parts.append(c)
|
||||
having_node = (
|
||||
self.__class__(having_parts, self.connector, self.negated)
|
||||
if having_parts
|
||||
else None
|
||||
)
|
||||
where_node = (
|
||||
self.__class__(where_parts, self.connector, self.negated)
|
||||
if where_parts
|
||||
else None
|
||||
)
|
||||
having_node = self.__class__(having_parts, self.connector, self.negated) if having_parts else None
|
||||
where_node = self.__class__(where_parts, self.connector, self.negated) if where_parts else None
|
||||
return where_node, having_node
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
@@ -103,24 +94,24 @@ class WhereNode(tree.Node):
|
||||
# counts.
|
||||
if empty_needed == 0:
|
||||
if self.negated:
|
||||
return "", []
|
||||
return '', []
|
||||
else:
|
||||
raise EmptyResultSet
|
||||
if full_needed == 0:
|
||||
if self.negated:
|
||||
raise EmptyResultSet
|
||||
else:
|
||||
return "", []
|
||||
conn = " %s " % self.connector
|
||||
return '', []
|
||||
conn = ' %s ' % self.connector
|
||||
sql_string = conn.join(result)
|
||||
if sql_string:
|
||||
if self.negated:
|
||||
# Some backends (Oracle at least) need parentheses
|
||||
# around the inner SQL in the negated case, even if the
|
||||
# inner SQL contains just a single expression.
|
||||
sql_string = "NOT (%s)" % sql_string
|
||||
sql_string = 'NOT (%s)' % sql_string
|
||||
elif len(result) > 1 or self.resolved:
|
||||
sql_string = "(%s)" % sql_string
|
||||
sql_string = '(%s)' % sql_string
|
||||
return sql_string, result_params
|
||||
|
||||
def get_group_by_cols(self, alias=None):
|
||||
@@ -142,10 +133,10 @@ class WhereNode(tree.Node):
|
||||
mapping old (current) alias values to the new values.
|
||||
"""
|
||||
for pos, child in enumerate(self.children):
|
||||
if hasattr(child, "relabel_aliases"):
|
||||
if hasattr(child, 'relabel_aliases'):
|
||||
# For example another WhereNode
|
||||
child.relabel_aliases(change_map)
|
||||
elif hasattr(child, "relabeled_clone"):
|
||||
elif hasattr(child, 'relabeled_clone'):
|
||||
self.children[pos] = child.relabeled_clone(change_map)
|
||||
|
||||
def clone(self):
|
||||
@@ -155,12 +146,9 @@ class WhereNode(tree.Node):
|
||||
value) tuples, or objects supporting .clone().
|
||||
"""
|
||||
clone = self.__class__._new_instance(
|
||||
children=None,
|
||||
connector=self.connector,
|
||||
negated=self.negated,
|
||||
)
|
||||
children=[], connector=self.connector, negated=self.negated)
|
||||
for child in self.children:
|
||||
if hasattr(child, "clone"):
|
||||
if hasattr(child, 'clone'):
|
||||
clone.children.append(child.clone())
|
||||
else:
|
||||
clone.children.append(child)
|
||||
@@ -194,20 +182,24 @@ class WhereNode(tree.Node):
|
||||
def contains_over_clause(self):
|
||||
return self._contains_over_clause(self)
|
||||
|
||||
@property
|
||||
def is_summary(self):
|
||||
return any(child.is_summary for child in self.children)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_leaf(expr, query, *args, **kwargs):
|
||||
if hasattr(expr, "resolve_expression"):
|
||||
if hasattr(expr, 'resolve_expression'):
|
||||
expr = expr.resolve_expression(query, *args, **kwargs)
|
||||
return expr
|
||||
|
||||
@classmethod
|
||||
def _resolve_node(cls, node, query, *args, **kwargs):
|
||||
if hasattr(node, "children"):
|
||||
if hasattr(node, 'children'):
|
||||
for child in node.children:
|
||||
cls._resolve_node(child, query, *args, **kwargs)
|
||||
if hasattr(node, "lhs"):
|
||||
if hasattr(node, 'lhs'):
|
||||
node.lhs = cls._resolve_leaf(node.lhs, query, *args, **kwargs)
|
||||
if hasattr(node, "rhs"):
|
||||
if hasattr(node, 'rhs'):
|
||||
node.rhs = cls._resolve_leaf(node.rhs, query, *args, **kwargs)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
@@ -216,30 +208,9 @@ class WhereNode(tree.Node):
|
||||
clone.resolved = True
|
||||
return clone
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
from django.db.models import BooleanField
|
||||
|
||||
return BooleanField()
|
||||
|
||||
def select_format(self, compiler, sql, params):
|
||||
# Wrap filters with a CASE WHEN expression if a database backend
|
||||
# (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
|
||||
# BY list.
|
||||
if not compiler.connection.features.supports_boolean_expr_in_select_clause:
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return self.output_field.get_db_converters(connection)
|
||||
|
||||
def get_lookup(self, lookup):
|
||||
return self.output_field.get_lookup(lookup)
|
||||
|
||||
|
||||
class NothingNode:
|
||||
"""A node that matches nothing."""
|
||||
|
||||
contains_aggregate = False
|
||||
|
||||
def as_sql(self, compiler=None, connection=None):
|
||||
@@ -268,7 +239,6 @@ class SubqueryConstraint:
|
||||
self.alias = alias
|
||||
self.columns = columns
|
||||
self.targets = targets
|
||||
query_object.clear_ordering(clear_default=True)
|
||||
self.query_object = query_object
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
|
||||
@@ -46,7 +46,7 @@ def create_namedtuple_class(*names):
|
||||
return unpickle_named_row, (names, tuple(self))
|
||||
|
||||
return type(
|
||||
"Row",
|
||||
(namedtuple("Row", names),),
|
||||
{"__reduce__": __reduce__, "__slots__": ()},
|
||||
'Row',
|
||||
(namedtuple('Row', names),),
|
||||
{'__reduce__': __reduce__, '__slots__': ()},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user