测试gitnore

This commit is contained in:
ladeng07
2022-05-06 15:45:57 +08:00
parent 12f390949b
commit 51552904f9
2347 changed files with 120102 additions and 53549 deletions
File diff suppressed because it is too large Load Diff
@@ -24,7 +24,7 @@ class FieldFile(File):
def __eq__(self, other):
# Older code may be expecting FileField values to be simple strings.
# By overriding the == operator, it can remain backwards compatibility.
if hasattr(other, "name"):
if hasattr(other, 'name'):
return self.name == other.name
return self.name == other
@@ -37,14 +37,12 @@ class FieldFile(File):
def _require_file(self):
if not self:
raise ValueError(
"The '%s' attribute has no file associated with it." % self.field.name
)
raise ValueError("The '%s' attribute has no file associated with it." % self.field.name)
def _get_file(self):
self._require_file()
if getattr(self, "_file", None) is None:
self._file = self.storage.open(self.name, "rb")
if getattr(self, '_file', None) is None:
self._file = self.storage.open(self.name, 'rb')
return self._file
def _set_file(self, file):
@@ -72,14 +70,13 @@ class FieldFile(File):
return self.file.size
return self.storage.size(self.name)
def open(self, mode="rb"):
def open(self, mode='rb'):
self._require_file()
if getattr(self, "_file", None) is None:
if getattr(self, '_file', None) is None:
self.file = self.storage.open(self.name, mode)
else:
self.file.open(mode)
return self
# open() doesn't alter the file's contents, but it does reset the pointer
open.alters_data = True
@@ -96,7 +93,6 @@ class FieldFile(File):
# Save the object because it has changed, unless save is False
if save:
self.instance.save()
save.alters_data = True
def delete(self, save=True):
@@ -104,7 +100,7 @@ class FieldFile(File):
return
# Only close the file if it's already open, which we know by the
# presence of self._file
if hasattr(self, "_file"):
if hasattr(self, '_file'):
self.close()
del self.file
@@ -116,16 +112,15 @@ class FieldFile(File):
if save:
self.instance.save()
delete.alters_data = True
@property
def closed(self):
file = getattr(self, "_file", None)
file = getattr(self, '_file', None)
return file is None or file.closed
def close(self):
file = getattr(self, "_file", None)
file = getattr(self, '_file', None)
if file is not None:
file.close()
@@ -134,12 +129,12 @@ class FieldFile(File):
# the file's name. Everything else will be restored later, by
# FileDescriptor below.
return {
"name": self.name,
"closed": False,
"_committed": True,
"_file": None,
"instance": self.instance,
"field": self.field,
'name': self.name,
'closed': False,
'_committed': True,
'_file': None,
'instance': self.instance,
'field': self.field,
}
def __setstate__(self, state):
@@ -161,7 +156,6 @@ class FileDescriptor(DeferredAttribute):
>>> with open('/path/to/hello.world') as f:
... instance.file = File(f)
"""
def __get__(self, instance, cls=None):
if instance is None:
return self
@@ -204,7 +198,7 @@ class FileDescriptor(DeferredAttribute):
# Finally, because of the (some would say boneheaded) way pickle works,
# the underlying FieldFile might not actually itself have an associated
# file. So we need to reset the details of the FieldFile in those cases.
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
elif isinstance(file, FieldFile) and not hasattr(file, 'field'):
file.instance = instance
file.field = self.field
file.storage = self.field.storage
@@ -231,10 +225,8 @@ class FileField(Field):
description = _("File")
def __init__(
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
):
self._primary_key_set_explicitly = "primary_key" in kwargs
def __init__(self, verbose_name=None, name=None, upload_to='', storage=None, **kwargs):
self._primary_key_set_explicitly = 'primary_key' in kwargs
self.storage = storage or default_storage
if callable(self.storage):
@@ -244,15 +236,11 @@ class FileField(Field):
if not isinstance(self.storage, Storage):
raise TypeError(
"%s.storage must be a subclass/instance of %s.%s"
% (
self.__class__.__qualname__,
Storage.__module__,
Storage.__qualname__,
)
% (self.__class__.__qualname__, Storage.__module__, Storage.__qualname__)
)
self.upload_to = upload_to
kwargs.setdefault("max_length", 100)
kwargs.setdefault('max_length', 100)
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
@@ -266,24 +254,23 @@ class FileField(Field):
if self._primary_key_set_explicitly:
return [
checks.Error(
"'primary_key' is not a valid argument for a %s."
% self.__class__.__name__,
"'primary_key' is not a valid argument for a %s." % self.__class__.__name__,
obj=self,
id="fields.E201",
id='fields.E201',
)
]
else:
return []
def _check_upload_to(self):
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
if isinstance(self.upload_to, str) and self.upload_to.startswith('/'):
return [
checks.Error(
"%s's 'upload_to' argument must be a relative path, not an "
"absolute path." % self.__class__.__name__,
obj=self,
id="fields.E202",
hint="Remove the leading slash.",
id='fields.E202',
hint='Remove the leading slash.',
)
]
else:
@@ -293,9 +280,9 @@ class FileField(Field):
name, path, args, kwargs = super().deconstruct()
if kwargs.get("max_length") == 100:
del kwargs["max_length"]
kwargs["upload_to"] = self.upload_to
kwargs['upload_to'] = self.upload_to
if self.storage is not default_storage:
kwargs["storage"] = getattr(self, "_storage_callable", self.storage)
kwargs['storage'] = getattr(self, '_storage_callable', self.storage)
return name, path, args, kwargs
def get_internal_type(self):
@@ -303,8 +290,7 @@ class FileField(Field):
def get_prep_value(self, value):
value = super().get_prep_value(value)
# Need to convert File objects provided via a form to string for
# database insertion.
# Need to convert File objects provided via a form to string for database insertion
if value is None:
return None
return str(value)
@@ -343,16 +329,14 @@ class FileField(Field):
if data is not None:
# This value will be converted to str and stored in the
# database, so leaving False as-is is not acceptable.
setattr(instance, self.name, data or "")
setattr(instance, self.name, data or '')
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.FileField,
"max_length": self.max_length,
**kwargs,
}
)
return super().formfield(**{
'form_class': forms.FileField,
'max_length': self.max_length,
**kwargs,
})
class ImageFileDescriptor(FileDescriptor):
@@ -360,7 +344,6 @@ class ImageFileDescriptor(FileDescriptor):
Just like the FileDescriptor, but for ImageFields. The only difference is
assigning the width/height to the width_field/height_field, if appropriate.
"""
def __set__(self, instance, value):
previous_file = instance.__dict__.get(self.field.attname)
super().__set__(instance, value)
@@ -381,7 +364,7 @@ class ImageFileDescriptor(FileDescriptor):
class ImageFieldFile(ImageFile, FieldFile):
def delete(self, save=True):
# Clear the image dimensions cache
if hasattr(self, "_dimensions_cache"):
if hasattr(self, '_dimensions_cache'):
del self._dimensions_cache
super().delete(save)
@@ -391,14 +374,7 @@ class ImageField(FileField):
descriptor_class = ImageFileDescriptor
description = _("Image")
def __init__(
self,
verbose_name=None,
name=None,
width_field=None,
height_field=None,
**kwargs,
):
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
self.width_field, self.height_field = width_field, height_field
super().__init__(verbose_name, name, **kwargs)
@@ -414,13 +390,11 @@ class ImageField(FileField):
except ImportError:
return [
checks.Error(
"Cannot use ImageField because Pillow is not installed.",
hint=(
"Get Pillow at https://pypi.org/project/Pillow/ "
'or run command "python -m pip install Pillow".'
),
'Cannot use ImageField because Pillow is not installed.',
hint=('Get Pillow at https://pypi.org/project/Pillow/ '
'or run command "python -m pip install Pillow".'),
obj=self,
id="fields.E210",
id='fields.E210',
)
]
else:
@@ -429,9 +403,9 @@ class ImageField(FileField):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.width_field:
kwargs["width_field"] = self.width_field
kwargs['width_field'] = self.width_field
if self.height_field:
kwargs["height_field"] = self.height_field
kwargs['height_field'] = self.height_field
return name, path, args, kwargs
def contribute_to_class(self, cls, name, **kwargs):
@@ -471,9 +445,9 @@ class ImageField(FileField):
if not file and not force:
return
dimension_fields_filled = not (
(self.width_field and not getattr(instance, self.width_field))
or (self.height_field and not getattr(instance, self.height_field))
dimension_fields_filled = not(
(self.width_field and not getattr(instance, self.width_field)) or
(self.height_field and not getattr(instance, self.height_field))
)
# When both dimension fields have values, we are most likely loading
# data from the database or updating an image field that already had
@@ -501,9 +475,7 @@ class ImageField(FileField):
setattr(instance, self.height_field, height)
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.ImageField,
**kwargs,
}
)
return super().formfield(**{
'form_class': forms.ImageField,
**kwargs,
})
@@ -10,36 +10,32 @@ from django.utils.translation import gettext_lazy as _
from . import Field
from .mixins import CheckFieldDefaultMixin
__all__ = ["JSONField"]
__all__ = ['JSONField']
class JSONField(CheckFieldDefaultMixin, Field):
empty_strings_allowed = False
description = _("A JSON object")
description = _('A JSON object')
default_error_messages = {
"invalid": _("Value must be valid JSON."),
'invalid': _('Value must be valid JSON.'),
}
_default_hint = ("dict", "{}")
_default_hint = ('dict', '{}')
def __init__(
self,
verbose_name=None,
name=None,
encoder=None,
decoder=None,
self, verbose_name=None, name=None, encoder=None, decoder=None,
**kwargs,
):
if encoder and not callable(encoder):
raise ValueError("The encoder parameter must be a callable object.")
raise ValueError('The encoder parameter must be a callable object.')
if decoder and not callable(decoder):
raise ValueError("The decoder parameter must be a callable object.")
raise ValueError('The decoder parameter must be a callable object.')
self.encoder = encoder
self.decoder = decoder
super().__init__(verbose_name, name, **kwargs)
def check(self, **kwargs):
errors = super().check(**kwargs)
databases = kwargs.get("databases") or []
databases = kwargs.get('databases') or []
errors.extend(self._check_supported(databases))
return errors
@@ -50,19 +46,20 @@ class JSONField(CheckFieldDefaultMixin, Field):
continue
connection = connections[db]
if (
self.model._meta.required_db_vendor
and self.model._meta.required_db_vendor != connection.vendor
self.model._meta.required_db_vendor and
self.model._meta.required_db_vendor != connection.vendor
):
continue
if not (
"supports_json_field" in self.model._meta.required_db_features
or connection.features.supports_json_field
'supports_json_field' in self.model._meta.required_db_features or
connection.features.supports_json_field
):
errors.append(
checks.Error(
"%s does not support JSONFields." % connection.display_name,
'%s does not support JSONFields.'
% connection.display_name,
obj=self.model,
id="fields.E180",
id='fields.E180',
)
)
return errors
@@ -70,9 +67,9 @@ class JSONField(CheckFieldDefaultMixin, Field):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
if self.encoder is not None:
kwargs["encoder"] = self.encoder
kwargs['encoder'] = self.encoder
if self.decoder is not None:
kwargs["decoder"] = self.decoder
kwargs['decoder'] = self.decoder
return name, path, args, kwargs
def from_db_value(self, value, expression, connection):
@@ -88,7 +85,7 @@ class JSONField(CheckFieldDefaultMixin, Field):
return value
def get_internal_type(self):
return "JSONField"
return 'JSONField'
def get_prep_value(self, value):
if value is None:
@@ -107,66 +104,64 @@ class JSONField(CheckFieldDefaultMixin, Field):
json.dumps(value, cls=self.encoder)
except TypeError:
raise exceptions.ValidationError(
self.error_messages["invalid"],
code="invalid",
params={"value": value},
self.error_messages['invalid'],
code='invalid',
params={'value': value},
)
def value_to_string(self, obj):
return self.value_from_object(obj)
def formfield(self, **kwargs):
return super().formfield(
**{
"form_class": forms.JSONField,
"encoder": self.encoder,
"decoder": self.decoder,
**kwargs,
}
)
return super().formfield(**{
'form_class': forms.JSONField,
'encoder': self.encoder,
'decoder': self.decoder,
**kwargs,
})
def compile_json_path(key_transforms, include_root=True):
path = ["$"] if include_root else []
path = ['$'] if include_root else []
for key_transform in key_transforms:
try:
num = int(key_transform)
except ValueError: # non-integer
path.append(".")
path.append('.')
path.append(json.dumps(key_transform))
else:
path.append("[%s]" % num)
return "".join(path)
path.append('[%s]' % num)
return ''.join(path)
class DataContains(PostgresOperatorLookup):
lookup_name = "contains"
postgres_operator = "@>"
lookup_name = 'contains'
postgres_operator = '@>'
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
"contains lookup is not supported on this database backend."
'contains lookup is not supported on this database backend.'
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(lhs_params) + tuple(rhs_params)
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
return 'JSON_CONTAINS(%s, %s)' % (lhs, rhs), params
class ContainedBy(PostgresOperatorLookup):
lookup_name = "contained_by"
postgres_operator = "<@"
lookup_name = 'contained_by'
postgres_operator = '<@'
def as_sql(self, compiler, connection):
if not connection.features.supports_json_field_contains:
raise NotSupportedError(
"contained_by lookup is not supported on this database backend."
'contained_by lookup is not supported on this database backend.'
)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)
params = tuple(rhs_params) + tuple(lhs_params)
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
return 'JSON_CONTAINS(%s, %s)' % (rhs, lhs), params
class HasKeyLookup(PostgresOperatorLookup):
@@ -175,13 +170,11 @@ class HasKeyLookup(PostgresOperatorLookup):
def as_sql(self, compiler, connection, template=None):
# Process JSON path from the left-hand side.
if isinstance(self.lhs, KeyTransform):
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
compiler, connection
)
lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(compiler, connection)
lhs_json_path = compile_json_path(lhs_key_transforms)
else:
lhs, lhs_params = self.process_lhs(compiler, connection)
lhs_json_path = "$"
lhs_json_path = '$'
sql = template % lhs
# Process JSON path from the right-hand side.
rhs = self.rhs
@@ -193,27 +186,20 @@ class HasKeyLookup(PostgresOperatorLookup):
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
else:
rhs_key_transforms = [key]
rhs_params.append(
"%s%s"
% (
lhs_json_path,
compile_json_path(rhs_key_transforms, include_root=False),
)
)
rhs_params.append('%s%s' % (
lhs_json_path,
compile_json_path(rhs_key_transforms, include_root=False),
))
# Add condition for each key.
if self.logical_operator:
sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
sql = '(%s)' % self.logical_operator.join([sql] * len(rhs_params))
return sql, tuple(lhs_params) + tuple(rhs_params)
def as_mysql(self, compiler, connection):
return self.as_sql(
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
)
return self.as_sql(compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)")
def as_oracle(self, compiler, connection):
sql, params = self.as_sql(
compiler, connection, template="JSON_EXISTS(%s, '%%s')"
)
sql, params = self.as_sql(compiler, connection, template="JSON_EXISTS(%s, '%%s')")
# Add paths directly into SQL because path expressions cannot be passed
# as bind variables on Oracle.
return sql % tuple(params), []
@@ -227,83 +213,64 @@ class HasKeyLookup(PostgresOperatorLookup):
return super().as_postgresql(compiler, connection)
def as_sqlite(self, compiler, connection):
return self.as_sql(
compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
)
return self.as_sql(compiler, connection, template='JSON_TYPE(%s, %%s) IS NOT NULL')
class HasKey(HasKeyLookup):
lookup_name = "has_key"
postgres_operator = "?"
lookup_name = 'has_key'
postgres_operator = '?'
prepare_rhs = False
class HasKeys(HasKeyLookup):
lookup_name = "has_keys"
postgres_operator = "?&"
logical_operator = " AND "
lookup_name = 'has_keys'
postgres_operator = '?&'
logical_operator = ' AND '
def get_prep_lookup(self):
return [str(item) for item in self.rhs]
class HasAnyKeys(HasKeys):
lookup_name = "has_any_keys"
postgres_operator = "?|"
logical_operator = " OR "
class CaseInsensitiveMixin:
"""
Mixin to allow case-insensitive comparison of JSON values on MySQL.
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == "mysql":
return "LOWER(%s)" % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == "mysql":
return "LOWER(%s)" % rhs, rhs_params
return rhs, rhs_params
lookup_name = 'has_any_keys'
postgres_operator = '?|'
logical_operator = ' OR '
class JSONExact(lookups.Exact):
can_use_none_as_rhs = True
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'sqlite':
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs == '%s' and rhs_params == [None]:
# Use JSON_TYPE instead of JSON_EXTRACT for NULLs.
lhs = "JSON_TYPE(%s, '$')" % lhs
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
# Treat None lookup values as null.
if rhs == "%s" and rhs_params == [None]:
rhs_params = ["null"]
if connection.vendor == "mysql":
if rhs == '%s' and rhs_params == [None]:
rhs_params = ['null']
if connection.vendor == 'mysql':
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
rhs = rhs % tuple(func)
return rhs, rhs_params
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
pass
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
class KeyTransform(Transform):
postgres_operator = "->"
postgres_nested_operator = "#>"
postgres_operator = '->'
postgres_nested_operator = '#>'
def __init__(self, key_name, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -316,50 +283,44 @@ class KeyTransform(Transform):
key_transforms.insert(0, previous.key_name)
previous = previous.lhs
lhs, params = compiler.compile(previous)
if connection.vendor == "oracle":
if connection.vendor == 'oracle':
# Escape string-formatting.
key_transforms = [key.replace("%", "%%") for key in key_transforms]
key_transforms = [key.replace('%', '%%') for key in key_transforms]
return lhs, params, key_transforms
def as_mysql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
def as_oracle(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
return (
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
% ((lhs, json_path) * 2)
"COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))" %
((lhs, json_path) * 2)
), tuple(params) * 2
def as_postgresql(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
if len(key_transforms) > 1:
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
sql = '(%s %s %%s)' % (lhs, self.postgres_nested_operator)
return sql, tuple(params) + (key_transforms,)
try:
lookup = int(self.key_name)
except ValueError:
lookup = self.key_name
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
return '(%s %s %%s)' % (lhs, self.postgres_operator), tuple(params) + (lookup,)
def as_sqlite(self, compiler, connection):
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
json_path = compile_json_path(key_transforms)
datatype_values = ",".join(
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
)
return (
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
return 'JSON_EXTRACT(%s, %%s)' % lhs, tuple(params) + (json_path,)
class KeyTextTransform(KeyTransform):
postgres_operator = "->>"
postgres_nested_operator = "#>>"
postgres_operator = '->>'
postgres_nested_operator = '#>>'
class KeyTransformTextLookupMixin:
@@ -369,21 +330,39 @@ class KeyTransformTextLookupMixin:
key values to text and performing the lookup on the resulting
representation.
"""
def __init__(self, key_transform, *args, **kwargs):
if not isinstance(key_transform, KeyTransform):
raise TypeError(
"Transform should be an instance of KeyTransform in order to "
"use this lookup."
'Transform should be an instance of KeyTransform in order to '
'use this lookup.'
)
key_text_transform = KeyTextTransform(
key_transform.key_name,
*key_transform.source_expressions,
key_transform.key_name, *key_transform.source_expressions,
**key_transform.extra,
)
super().__init__(key_text_transform, *args, **kwargs)
class CaseInsensitiveMixin:
"""
Mixin to allow case-insensitive comparison of JSON values on MySQL.
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
Because utf8mb4_bin is a binary collation, comparison of JSON values is
case-sensitive.
"""
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'mysql':
return 'LOWER(%s)' % lhs, lhs_params
return lhs, lhs_params
def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == 'mysql':
return 'LOWER(%s)' % rhs, rhs_params
return rhs, rhs_params
class KeyTransformIsNull(lookups.IsNull):
# key__isnull=False is the same as has_key='key'
def as_oracle(self, compiler, connection):
@@ -395,12 +374,12 @@ class KeyTransformIsNull(lookups.IsNull):
return sql, params
# Column doesn't have a key or IS NULL.
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
return '(NOT %s OR %s IS NULL)' % (sql, lhs), tuple(params) + tuple(lhs_params)
def as_sqlite(self, compiler, connection):
template = "JSON_TYPE(%s, %%s) IS NULL"
template = 'JSON_TYPE(%s, %%s) IS NULL'
if not self.rhs:
template = "JSON_TYPE(%s, %%s) IS NOT NULL"
template = 'JSON_TYPE(%s, %%s) IS NOT NULL'
return HasKey(self.lhs.lhs, self.lhs.key_name).as_sql(
compiler,
connection,
@@ -411,81 +390,75 @@ class KeyTransformIsNull(lookups.IsNull):
class KeyTransformIn(lookups.In):
def resolve_expression_parameter(self, compiler, connection, sql, param):
sql, params = super().resolve_expression_parameter(
compiler,
connection,
sql,
param,
compiler, connection, sql, param,
)
if (
not hasattr(param, "as_sql")
and not connection.features.has_native_json_field
not hasattr(param, 'as_sql') and
not connection.features.has_native_json_field
):
if connection.vendor == "oracle":
if connection.vendor == 'oracle':
value = json.loads(param)
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
if isinstance(value, (list, dict)):
sql = sql % "JSON_QUERY"
sql = sql % 'JSON_QUERY'
else:
sql = sql % "JSON_VALUE"
elif connection.vendor == "mysql" or (
connection.vendor == "sqlite"
and params[0] not in connection.ops.jsonfield_datatype_values
):
sql = sql % 'JSON_VALUE'
elif connection.vendor in {'sqlite', 'mysql'}:
sql = "JSON_EXTRACT(%s, '$')"
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
sql = "JSON_UNQUOTE(%s)" % sql
if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
sql = 'JSON_UNQUOTE(%s)' % sql
return sql, params
class KeyTransformExact(JSONExact):
def process_lhs(self, compiler, connection):
lhs, lhs_params = super().process_lhs(compiler, connection)
if connection.vendor == 'sqlite':
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs == '%s' and rhs_params == ['null']:
lhs, *_ = self.lhs.preprocess_lhs(compiler, connection)
lhs = 'JSON_TYPE(%s, %%s)' % lhs
return lhs, lhs_params
def process_rhs(self, compiler, connection):
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
rhs, rhs_params = super().process_rhs(compiler, connection)
if connection.vendor == "oracle":
if connection.vendor == 'oracle':
func = []
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
for value in rhs_params:
value = json.loads(value)
if isinstance(value, (list, dict)):
func.append(sql % "JSON_QUERY")
func.append(sql % 'JSON_QUERY')
else:
func.append(sql % "JSON_VALUE")
func.append(sql % 'JSON_VALUE')
rhs = rhs % tuple(func)
elif connection.vendor == "sqlite":
func = []
for value in rhs_params:
if value in connection.ops.jsonfield_datatype_values:
func.append("%s")
else:
func.append("JSON_EXTRACT(%s, '$')")
elif connection.vendor == 'sqlite':
func = ["JSON_EXTRACT(%s, '$')" if value != 'null' else '%s' for value in rhs_params]
rhs = rhs % tuple(func)
return rhs, rhs_params
def as_oracle(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
if rhs_params == ["null"]:
if rhs_params == ['null']:
# Field has key and it's NULL.
has_key_expr = HasKey(self.lhs.lhs, self.lhs.key_name)
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
is_null_expr = self.lhs.get_lookup('isnull')(self.lhs, True)
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
return (
"%s AND %s" % (has_key_sql, is_null_sql),
'%s AND %s' % (has_key_sql, is_null_sql),
tuple(has_key_params) + tuple(is_null_params),
)
return super().as_sql(compiler, connection)
class KeyTransformIExact(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
class KeyTransformIExact(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact):
pass
class KeyTransformIContains(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
class KeyTransformIContains(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains):
pass
@@ -493,9 +466,7 @@ class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
pass
class KeyTransformIStartsWith(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
class KeyTransformIStartsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith):
pass
@@ -503,9 +474,7 @@ class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
pass
class KeyTransformIEndsWith(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
class KeyTransformIEndsWith(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith):
pass
@@ -513,9 +482,7 @@ class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
pass
class KeyTransformIRegex(
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
class KeyTransformIRegex(CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex):
pass
@@ -562,6 +529,7 @@ KeyTransform.register_lookup(KeyTransformGte)
class KeyTransformFactory:
def __init__(self, key_name):
self.key_name = key_name
@@ -29,25 +29,22 @@ class FieldCacheMixin:
class CheckFieldDefaultMixin:
_default_hint = ("<valid default>", "<invalid default>")
_default_hint = ('<valid default>', '<invalid default>')
def _check_default(self):
if (
self.has_default()
and self.default is not None
and not callable(self.default)
):
if self.has_default() and self.default is not None and not callable(self.default):
return [
checks.Warning(
"%s default should be a callable instead of an instance "
"so that it's not shared between all field instances."
% (self.__class__.__name__,),
"so that it's not shared between all field instances." % (
self.__class__.__name__,
),
hint=(
"Use a callable instead, e.g., use `%s` instead of "
"`%s`." % self._default_hint
'Use a callable instead, e.g., use `%s` instead of '
'`%s`.' % self._default_hint
),
obj=self,
id="fields.E010",
id='fields.E010',
)
]
else:
@@ -13,6 +13,6 @@ class OrderWrt(fields.IntegerField):
"""
def __init__(self, *args, **kwargs):
kwargs["name"] = "_order"
kwargs["editable"] = False
kwargs['name'] = '_order'
kwargs['editable'] = False
super().__init__(*args, **kwargs)
File diff suppressed because it is too large Load Diff
@@ -74,9 +74,7 @@ from django.utils.functional import cached_property
class ForeignKeyDeferredAttribute(DeferredAttribute):
def __set__(self, instance, value):
if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(
instance
):
if instance.__dict__.get(self.field.attname) != value and self.field.is_cached(instance):
self.field.delete_cached_value(instance)
instance.__dict__[self.field.attname] = value
@@ -103,16 +101,14 @@ class ForwardManyToOneDescriptor:
# related model might not be resolved yet; `self.field.model` might
# still be a string model reference.
return type(
"RelatedObjectDoesNotExist",
(self.field.remote_field.model.DoesNotExist, AttributeError),
{
"__module__": self.field.model.__module__,
"__qualname__": "%s.%s.RelatedObjectDoesNotExist"
% (
'RelatedObjectDoesNotExist',
(self.field.remote_field.model.DoesNotExist, AttributeError), {
'__module__': self.field.model.__module__,
'__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
self.field.model.__qualname__,
self.field.name,
),
},
}
)
def is_cached(self, instance):
@@ -139,12 +135,9 @@ class ForwardManyToOneDescriptor:
# The check for len(...) == 1 is a special case that allows the query
# to be join-less and smaller. Refs #21760.
if remote_field.is_hidden() or len(self.field.foreign_related_fields) == 1:
query = {
"%s__in"
% related_field.name: {instance_attr(inst)[0] for inst in instances}
}
query = {'%s__in' % related_field.name: {instance_attr(inst)[0] for inst in instances}}
else:
query = {"%s__in" % self.field.related_query_name(): instances}
query = {'%s__in' % self.field.related_query_name(): instances}
queryset = queryset.filter(**query)
# Since we're going to assign directly in the cache,
@@ -153,14 +146,7 @@ class ForwardManyToOneDescriptor:
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
remote_field.set_cached_value(rel_obj, instance)
return (
queryset,
rel_obj_attr,
instance_attr,
True,
self.field.get_cache_name(),
False,
)
return queryset, rel_obj_attr, instance_attr, True, self.field.get_cache_name(), False
def get_object(self, instance):
qs = self.get_queryset(instance=instance)
@@ -187,11 +173,7 @@ class ForwardManyToOneDescriptor:
rel_obj = self.field.get_cached_value(instance)
except KeyError:
has_value = None not in self.field.get_local_related_value(instance)
ancestor_link = (
instance._meta.get_ancestor_link(self.field.model)
if has_value
else None
)
ancestor_link = instance._meta.get_ancestor_link(self.field.model) if has_value else None
if ancestor_link and ancestor_link.is_cached(instance):
# An ancestor link will exist if this field is defined on a
# multi-table inheritance parent of the instance's class.
@@ -229,12 +211,9 @@ class ForwardManyToOneDescriptor:
- ``value`` is the ``parent`` instance on the right of the equal sign
"""
# An object must be an instance of the related class.
if value is not None and not isinstance(
value, self.field.remote_field.model._meta.concrete_model
):
if value is not None and not isinstance(value, self.field.remote_field.model._meta.concrete_model):
raise ValueError(
'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
% (
'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
value,
instance._meta.object_name,
self.field.name,
@@ -243,18 +222,11 @@ class ForwardManyToOneDescriptor:
)
elif value is not None:
if instance._state.db is None:
instance._state.db = router.db_for_write(
instance.__class__, instance=value
)
instance._state.db = router.db_for_write(instance.__class__, instance=value)
if value._state.db is None:
value._state.db = router.db_for_write(
value.__class__, instance=instance
)
value._state.db = router.db_for_write(value.__class__, instance=instance)
if not router.allow_relation(value, instance):
raise ValueError(
'Cannot assign "%r": the current database router prevents this '
"relation." % value
)
raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
remote_field = self.field.remote_field
# If we're setting the value of a OneToOneField to None, we need to clear
@@ -342,15 +314,12 @@ class ForwardOneToOneDescriptor(ForwardManyToOneDescriptor):
opts = instance._meta
# Inherited primary key fields from this object's base classes.
inherited_pk_fields = [
field
for field in opts.concrete_fields
field for field in opts.concrete_fields
if field.primary_key and field.remote_field
]
for field in inherited_pk_fields:
rel_model_pk_name = field.remote_field.model._meta.pk.attname
raw_value = (
getattr(value, rel_model_pk_name) if value is not None else None
)
raw_value = getattr(value, rel_model_pk_name) if value is not None else None
setattr(instance, rel_model_pk_name, raw_value)
@@ -377,15 +346,13 @@ class ReverseOneToOneDescriptor:
# The exception isn't created at initialization time for the sake of
# consistency with `ForwardManyToOneDescriptor`.
return type(
"RelatedObjectDoesNotExist",
(self.related.related_model.DoesNotExist, AttributeError),
{
"__module__": self.related.model.__module__,
"__qualname__": "%s.%s.RelatedObjectDoesNotExist"
% (
'RelatedObjectDoesNotExist',
(self.related.related_model.DoesNotExist, AttributeError), {
'__module__': self.related.model.__module__,
'__qualname__': '%s.%s.RelatedObjectDoesNotExist' % (
self.related.model.__qualname__,
self.related.name,
),
)
},
)
@@ -403,7 +370,7 @@ class ReverseOneToOneDescriptor:
rel_obj_attr = self.related.field.get_local_related_value
instance_attr = self.related.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
query = {"%s__in" % self.related.field.name: instances}
query = {'%s__in' % self.related.field.name: instances}
queryset = queryset.filter(**query)
# Since we're going to assign directly in the cache,
@@ -411,14 +378,7 @@ class ReverseOneToOneDescriptor:
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
self.related.field.set_cached_value(rel_obj, instance)
return (
queryset,
rel_obj_attr,
instance_attr,
True,
self.related.get_cache_name(),
False,
)
return queryset, rel_obj_attr, instance_attr, True, self.related.get_cache_name(), False
def __get__(self, instance, cls=None):
"""
@@ -459,8 +419,10 @@ class ReverseOneToOneDescriptor:
if rel_obj is None:
raise self.RelatedObjectDoesNotExist(
"%s has no %s."
% (instance.__class__.__name__, self.related.get_accessor_name())
"%s has no %s." % (
instance.__class__.__name__,
self.related.get_accessor_name()
)
)
else:
return rel_obj
@@ -496,8 +458,7 @@ class ReverseOneToOneDescriptor:
elif not isinstance(value, self.related.related_model):
# An object must be an instance of the related class.
raise ValueError(
'Cannot assign "%r": "%s.%s" must be a "%s" instance.'
% (
'Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (
value,
instance._meta.object_name,
self.related.get_accessor_name(),
@@ -506,25 +467,14 @@ class ReverseOneToOneDescriptor:
)
else:
if instance._state.db is None:
instance._state.db = router.db_for_write(
instance.__class__, instance=value
)
instance._state.db = router.db_for_write(instance.__class__, instance=value)
if value._state.db is None:
value._state.db = router.db_for_write(
value.__class__, instance=instance
)
value._state.db = router.db_for_write(value.__class__, instance=instance)
if not router.allow_relation(value, instance):
raise ValueError(
'Cannot assign "%r": the current database router prevents this '
"relation." % value
)
raise ValueError('Cannot assign "%r": the current database router prevents this relation.' % value)
related_pk = tuple(
getattr(instance, field.attname)
for field in self.related.field.foreign_related_fields
)
# Set the value of the related field to the value of the related
# object's related field.
related_pk = tuple(getattr(instance, field.attname) for field in self.related.field.foreign_related_fields)
# Set the value of the related field to the value of the related object's related field
for index, field in enumerate(self.related.field.local_related_fields):
setattr(value, field.attname, related_pk[index])
@@ -587,13 +537,13 @@ class ReverseManyToOneDescriptor:
def _get_set_deprecation_msg_params(self):
return (
"reverse side of a related set",
'reverse side of a related set',
self.rel.get_accessor_name(),
)
def __set__(self, instance, value):
raise TypeError(
"Direct assignment to the %s is prohibited. Use %s.set() instead."
'Direct assignment to the %s is prohibited. Use %s.set() instead.'
% self._get_set_deprecation_msg_params(),
)
@@ -620,7 +570,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
manager = getattr(self.model, manager)
manager_class = create_reverse_many_to_one_manager(manager.__class__, rel)
return manager_class(self.instance)
do_not_call_in_templates = True
def _apply_rel_filters(self, queryset):
@@ -628,9 +577,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
Filter the queryset for the instance this manager is bound to.
"""
db = self._db or router.db_for_read(self.model, instance=self.instance)
empty_strings_as_null = connections[
db
].features.interprets_empty_strings_as_nulls
empty_strings_as_null = connections[db].features.interprets_empty_strings_as_nulls
queryset._add_hints(instance=self.instance)
if self._db:
queryset = queryset.using(self._db)
@@ -638,7 +585,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
queryset = queryset.filter(**self.core_filters)
for field in self.field.foreign_related_fields:
val = getattr(self.instance, field.attname)
if val is None or (val == "" and empty_strings_as_null):
if val is None or (val == '' and empty_strings_as_null):
return queryset.none()
if self.field.many_to_one:
# Guard against field-like objects such as GenericRelation
@@ -650,34 +597,24 @@ def create_reverse_many_to_one_manager(superclass, rel):
except FieldError:
# The relationship has multiple target fields. Use a tuple
# for related object id.
rel_obj_id = tuple(
[
getattr(self.instance, target_field.attname)
for target_field in self.field.get_path_info()[
-1
].target_fields
]
)
rel_obj_id = tuple([
getattr(self.instance, target_field.attname)
for target_field in self.field.get_path_info()[-1].target_fields
])
else:
rel_obj_id = getattr(self.instance, target_field.attname)
queryset._known_related_objects = {
self.field: {rel_obj_id: self.instance}
}
queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
return queryset
def _remove_prefetched_objects(self):
try:
self.instance._prefetched_objects_cache.pop(
self.field.remote_field.get_cache_name()
)
self.instance._prefetched_objects_cache.pop(self.field.remote_field.get_cache_name())
except (AttributeError, KeyError):
pass # nothing to clear from cache
def get_queryset(self):
try:
return self.instance._prefetched_objects_cache[
self.field.remote_field.get_cache_name()
]
return self.instance._prefetched_objects_cache[self.field.remote_field.get_cache_name()]
except (AttributeError, KeyError):
queryset = super().get_queryset()
return self._apply_rel_filters(queryset)
@@ -692,7 +629,7 @@ def create_reverse_many_to_one_manager(superclass, rel):
rel_obj_attr = self.field.get_local_related_value
instance_attr = self.field.get_foreign_related_value
instances_dict = {instance_attr(inst): inst for inst in instances}
query = {"%s__in" % self.field.name: instances}
query = {'%s__in' % self.field.name: instances}
queryset = queryset.filter(**query)
# Since we just bypassed this class' get_queryset(), we must manage
@@ -709,13 +646,9 @@ def create_reverse_many_to_one_manager(superclass, rel):
def check_and_update_obj(obj):
if not isinstance(obj, self.model):
raise TypeError(
"'%s' instance expected, got %r"
% (
self.model._meta.object_name,
obj,
)
)
raise TypeError("'%s' instance expected, got %r" % (
self.model._meta.object_name, obj,
))
setattr(obj, self.field.name, self.instance)
if bulk:
@@ -728,44 +661,36 @@ def create_reverse_many_to_one_manager(superclass, rel):
"the object first." % obj
)
pks.append(obj.pk)
self.model._base_manager.using(db).filter(pk__in=pks).update(
**{
self.field.name: self.instance,
}
)
self.model._base_manager.using(db).filter(pk__in=pks).update(**{
self.field.name: self.instance,
})
else:
with transaction.atomic(using=db, savepoint=False):
for obj in objs:
check_and_update_obj(obj)
obj.save()
add.alters_data = True
def create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).create(**kwargs)
create.alters_data = True
def get_or_create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).get_or_create(**kwargs)
get_or_create.alters_data = True
def update_or_create(self, **kwargs):
kwargs[self.field.name] = self.instance
db = router.db_for_write(self.model, instance=self.instance)
return super(RelatedManager, self.db_manager(db)).update_or_create(**kwargs)
update_or_create.alters_data = True
# remove() and clear() are only provided if the ForeignKey can have a
# value of null.
# remove() and clear() are only provided if the ForeignKey can have a value of null.
if rel.field.null:
def remove(self, *objs, bulk=True):
if not objs:
return
@@ -773,13 +698,9 @@ def create_reverse_many_to_one_manager(superclass, rel):
old_ids = set()
for obj in objs:
if not isinstance(obj, self.model):
raise TypeError(
"'%s' instance expected, got %r"
% (
self.model._meta.object_name,
obj,
)
)
raise TypeError("'%s' instance expected, got %r" % (
self.model._meta.object_name, obj,
))
# Is obj actually part of this descriptor set?
if self.field.get_local_related_value(obj) == val:
old_ids.add(obj.pk)
@@ -788,12 +709,10 @@ def create_reverse_many_to_one_manager(superclass, rel):
"%r is not related to %r." % (obj, self.instance)
)
self._clear(self.filter(pk__in=old_ids), bulk)
remove.alters_data = True
def clear(self, *, bulk=True):
self._clear(self, bulk)
clear.alters_data = True
def _clear(self, queryset, bulk):
@@ -808,7 +727,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
for obj in queryset:
setattr(obj, self.field.name, None)
obj.save(update_fields=[self.field.name])
_clear.alters_data = True
def set(self, objs, *, bulk=True, clear=False):
@@ -835,7 +753,6 @@ def create_reverse_many_to_one_manager(superclass, rel):
self.add(*new_objs, bulk=bulk)
else:
self.add(*objs, bulk=bulk)
set.alters_data = True
return RelatedManager
@@ -882,8 +799,7 @@ class ManyToManyDescriptor(ReverseManyToOneDescriptor):
def _get_set_deprecation_msg_params(self):
return (
"%s side of a many-to-many set"
% ("reverse" if self.reverse else "forward"),
'%s side of a many-to-many set' % ('reverse' if self.reverse else 'forward'),
self.rel.get_accessor_name() if self.reverse else self.field.name,
)
@@ -926,51 +842,42 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.core_filters = {}
self.pk_field_names = {}
for lh_field, rh_field in self.source_field.related_fields:
core_filter_key = "%s__%s" % (self.query_field_name, rh_field.name)
core_filter_key = '%s__%s' % (self.query_field_name, rh_field.name)
self.core_filters[core_filter_key] = getattr(instance, rh_field.attname)
self.pk_field_names[lh_field.name] = rh_field.name
self.related_val = self.source_field.get_foreign_related_value(instance)
if None in self.related_val:
raise ValueError(
'"%r" needs to have a value for field "%s" before '
"this many-to-many relationship can be used."
% (instance, self.pk_field_names[self.source_field_name])
)
raise ValueError('"%r" needs to have a value for field "%s" before '
'this many-to-many relationship can be used.' %
(instance, self.pk_field_names[self.source_field_name]))
# Even if this relation is not to pk, we require still pk value.
# The wish is that the instance has been already saved to DB,
# although having a pk value isn't a guarantee of that.
if instance.pk is None:
raise ValueError(
"%r instance needs to have a primary key value before "
"a many-to-many relationship can be used."
% instance.__class__.__name__
)
raise ValueError("%r instance needs to have a primary key value before "
"a many-to-many relationship can be used." %
instance.__class__.__name__)
def __call__(self, *, manager):
manager = getattr(self.model, manager)
manager_class = create_forward_many_to_many_manager(
manager.__class__, rel, reverse
)
manager_class = create_forward_many_to_many_manager(manager.__class__, rel, reverse)
return manager_class(instance=self.instance)
do_not_call_in_templates = True
def _build_remove_filters(self, removed_vals):
filters = Q((self.source_field_name, self.related_val))
filters = Q(**{self.source_field_name: self.related_val})
# No need to add a subquery condition if removed_vals is a QuerySet without
# filters.
removed_vals_filters = (
not isinstance(removed_vals, QuerySet) or removed_vals._has_filters()
)
removed_vals_filters = (not isinstance(removed_vals, QuerySet) or
removed_vals._has_filters())
if removed_vals_filters:
filters &= Q((f"{self.target_field_name}__in", removed_vals))
filters &= Q(**{'%s__in' % self.target_field_name: removed_vals})
if self.symmetrical:
symmetrical_filters = Q((self.target_field_name, self.related_val))
symmetrical_filters = Q(**{self.target_field_name: self.related_val})
if removed_vals_filters:
symmetrical_filters &= Q(
(f"{self.source_field_name}__in", removed_vals)
)
**{'%s__in' % self.source_field_name: removed_vals})
filters |= symmetrical_filters
return filters
@@ -1004,7 +911,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)
query = {"%s__in" % self.query_field_name: instances}
query = {'%s__in' % self.query_field_name: instances}
queryset = queryset._next_is_sticky().filter(**query)
# M2M: need to annotate the query in order to get the primary model
@@ -1018,18 +925,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
join_table = fk.model._meta.db_table
connection = connections[queryset.db]
qn = connection.ops.quote_name
queryset = queryset.extra(
select={
"_prefetch_related_val_%s"
% f.attname: "%s.%s"
% (qn(join_table), qn(f.column))
for f in fk.local_related_fields
}
)
queryset = queryset.extra(select={
'_prefetch_related_val_%s' % f.attname:
'%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
return (
queryset,
lambda result: tuple(
getattr(result, "_prefetch_related_val_%s" % f.attname)
getattr(result, '_prefetch_related_val_%s' % f.attname)
for f in fk.local_related_fields
),
lambda inst: tuple(
@@ -1046,9 +948,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
db = router.db_for_write(self.through, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
self._add_items(
self.source_field_name,
self.target_field_name,
*objs,
self.source_field_name, self.target_field_name, *objs,
through_defaults=through_defaults,
)
# If this is a symmetrical m2m relation to self, add the mirror
@@ -1060,41 +960,30 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
*objs,
through_defaults=through_defaults,
)
add.alters_data = True
def remove(self, *objs):
self._remove_prefetched_objects()
self._remove_items(self.source_field_name, self.target_field_name, *objs)
remove.alters_data = True
def clear(self):
db = router.db_for_write(self.through, instance=self.instance)
with transaction.atomic(using=db, savepoint=False):
signals.m2m_changed.send(
sender=self.through,
action="pre_clear",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=None,
using=db,
sender=self.through, action="pre_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db,
)
self._remove_prefetched_objects()
filters = self._build_remove_filters(super().get_queryset().using(db))
self.through._default_manager.using(db).filter(filters).delete()
signals.m2m_changed.send(
sender=self.through,
action="post_clear",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=None,
using=db,
sender=self.through, action="post_clear",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=None, using=db,
)
clear.alters_data = True
def set(self, objs, *, clear=False, through_defaults=None):
@@ -1108,11 +997,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.clear()
self.add(*objs, through_defaults=through_defaults)
else:
old_ids = set(
self.using(db).values_list(
self.target_field.target_field.attname, flat=True
)
)
old_ids = set(self.using(db).values_list(self.target_field.target_field.attname, flat=True))
new_objs = []
for obj in objs:
@@ -1128,7 +1013,6 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
self.remove(*old_ids)
self.add(*new_objs, through_defaults=through_defaults)
set.alters_data = True
def create(self, *, through_defaults=None, **kwargs):
@@ -1136,33 +1020,26 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
new_obj = super(ManyRelatedManager, self.db_manager(db)).create(**kwargs)
self.add(new_obj, through_defaults=through_defaults)
return new_obj
create.alters_data = True
def get_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
**kwargs
)
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj, through_defaults=through_defaults)
return obj, created
get_or_create.alters_data = True
def update_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(
ManyRelatedManager, self.db_manager(db)
).update_or_create(**kwargs)
obj, created = super(ManyRelatedManager, self.db_manager(db)).update_or_create(**kwargs)
# We only need to add() if created because if we got an object back
# from get() then the relationship already exists.
if created:
self.add(obj, through_defaults=through_defaults)
return obj, created
update_or_create.alters_data = True
def _get_target_ids(self, target_field_name, objs):
@@ -1170,7 +1047,6 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
Return the set of ids of `objs` that the target field references.
"""
from django.db.models import Model
target_ids = set()
target_field = self.through._meta.get_field(target_field_name)
for obj in objs:
@@ -1178,42 +1054,36 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
if not router.allow_relation(obj, self.instance):
raise ValueError(
'Cannot add "%r": instance is on database "%s", '
'value is on database "%s"'
% (obj, self.instance._state.db, obj._state.db)
'value is on database "%s"' %
(obj, self.instance._state.db, obj._state.db)
)
target_id = target_field.get_foreign_related_value(obj)[0]
if target_id is None:
raise ValueError(
'Cannot add "%r": the value for field "%s" is None'
% (obj, target_field_name)
'Cannot add "%r": the value for field "%s" is None' %
(obj, target_field_name)
)
target_ids.add(target_id)
elif isinstance(obj, Model):
raise TypeError(
"'%s' instance expected, got %r"
% (self.model._meta.object_name, obj)
"'%s' instance expected, got %r" %
(self.model._meta.object_name, obj)
)
else:
target_ids.add(target_field.get_prep_value(obj))
return target_ids
def _get_missing_target_ids(
self, source_field_name, target_field_name, db, target_ids
):
def _get_missing_target_ids(self, source_field_name, target_field_name, db, target_ids):
"""
Return the subset of ids of `objs` that aren't already assigned to
this relationship.
"""
vals = (
self.through._default_manager.using(db)
.values_list(target_field_name, flat=True)
.filter(
**{
source_field_name: self.related_val[0],
"%s__in" % target_field_name: target_ids,
}
)
)
vals = self.through._default_manager.using(db).values_list(
target_field_name, flat=True
).filter(**{
source_field_name: self.related_val[0],
'%s__in' % target_field_name: target_ids,
})
return target_ids.difference(vals)
def _get_add_plan(self, db, source_field_name):
@@ -1231,53 +1101,39 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
# user-defined intermediary models as they could have other fields
# causing conflicts which must be surfaced.
can_ignore_conflicts = (
connections[db].features.supports_ignore_conflicts
and self.through._meta.auto_created is not False
connections[db].features.supports_ignore_conflicts and
self.through._meta.auto_created is not False
)
# Don't send the signal when inserting duplicate data row
# for symmetrical reverse entries.
must_send_signals = (
self.reverse or source_field_name == self.source_field_name
) and (signals.m2m_changed.has_listeners(self.through))
must_send_signals = (self.reverse or source_field_name == self.source_field_name) and (
signals.m2m_changed.has_listeners(self.through)
)
# Fast addition through bulk insertion can only be performed
# if no m2m_changed listeners are connected for self.through
# as they require the added set of ids to be provided via
# pk_set.
return (
can_ignore_conflicts,
must_send_signals,
(can_ignore_conflicts and not must_send_signals),
)
return can_ignore_conflicts, must_send_signals, (can_ignore_conflicts and not must_send_signals)
def _add_items(
self, source_field_name, target_field_name, *objs, through_defaults=None
):
def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None):
# source_field_name: the PK fieldname in join table for the source object
# target_field_name: the PK fieldname in join table for the target object
# *objs - objects to add. Either object instances, or primary keys
# of object instances.
# *objs - objects to add. Either object instances, or primary keys of object instances.
if not objs:
return
through_defaults = dict(resolve_callables(through_defaults or {}))
target_ids = self._get_target_ids(target_field_name, objs)
db = router.db_for_write(self.through, instance=self.instance)
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(
db, source_field_name
)
can_ignore_conflicts, must_send_signals, can_fast_add = self._get_add_plan(db, source_field_name)
if can_fast_add:
self.through._default_manager.using(db).bulk_create(
[
self.through(
**{
"%s_id" % source_field_name: self.related_val[0],
"%s_id" % target_field_name: target_id,
}
)
for target_id in target_ids
],
ignore_conflicts=True,
)
self.through._default_manager.using(db).bulk_create([
self.through(**{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in target_ids
], ignore_conflicts=True)
return
missing_target_ids = self._get_missing_target_ids(
@@ -1286,38 +1142,24 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
with transaction.atomic(using=db, savepoint=False):
if must_send_signals:
signals.m2m_changed.send(
sender=self.through,
action="pre_add",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=missing_target_ids,
using=db,
sender=self.through, action='pre_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db,
)
# Add the ones that aren't there already.
self.through._default_manager.using(db).bulk_create(
[
self.through(
**through_defaults,
**{
"%s_id" % source_field_name: self.related_val[0],
"%s_id" % target_field_name: target_id,
},
)
for target_id in missing_target_ids
],
ignore_conflicts=can_ignore_conflicts,
)
self.through._default_manager.using(db).bulk_create([
self.through(**through_defaults, **{
'%s_id' % source_field_name: self.related_val[0],
'%s_id' % target_field_name: target_id,
})
for target_id in missing_target_ids
], ignore_conflicts=can_ignore_conflicts)
if must_send_signals:
signals.m2m_changed.send(
sender=self.through,
action="post_add",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=missing_target_ids,
using=db,
sender=self.through, action='post_add',
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=missing_target_ids, using=db,
)
def _remove_items(self, source_field_name, target_field_name, *objs):
@@ -1341,32 +1183,23 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
with transaction.atomic(using=db, savepoint=False):
# Send a signal to the other end if need be.
signals.m2m_changed.send(
sender=self.through,
action="pre_remove",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=old_ids,
using=db,
sender=self.through, action="pre_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db,
)
target_model_qs = super().get_queryset()
if target_model_qs._has_filters():
old_vals = target_model_qs.using(db).filter(
**{"%s__in" % self.target_field.target_field.attname: old_ids}
)
old_vals = target_model_qs.using(db).filter(**{
'%s__in' % self.target_field.target_field.attname: old_ids})
else:
old_vals = old_ids
filters = self._build_remove_filters(old_vals)
self.through._default_manager.using(db).filter(filters).delete()
signals.m2m_changed.send(
sender=self.through,
action="post_remove",
instance=self.instance,
reverse=self.reverse,
model=self.model,
pk_set=old_ids,
using=db,
sender=self.through, action="post_remove",
instance=self.instance, reverse=self.reverse,
model=self.model, pk_set=old_ids, using=db,
)
return ManyRelatedManager
@@ -1,10 +1,5 @@
from django.db.models.lookups import (
Exact,
GreaterThan,
GreaterThanOrEqual,
In,
IsNull,
LessThan,
Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
LessThanOrEqual,
)
@@ -13,40 +8,29 @@ class MultiColSource:
contains_aggregate = False
def __init__(self, alias, targets, sources, field):
self.targets, self.sources, self.field, self.alias = (
targets,
sources,
field,
alias,
)
self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
self.output_field = self.field
def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__, self.alias, self.field)
return "{}({}, {})".format(
self.__class__.__name__, self.alias, self.field)
def relabeled_clone(self, relabels):
return self.__class__(
relabels.get(self.alias, self.alias), self.targets, self.sources, self.field
)
return self.__class__(relabels.get(self.alias, self.alias),
self.targets, self.sources, self.field)
def get_lookup(self, lookup):
return self.output_field.get_lookup(lookup)
def resolve_expression(self, *args, **kwargs):
return self
def get_normalized_value(value, lhs):
from django.db.models import Model
if isinstance(value, Model):
value_list = []
sources = lhs.output_field.get_path_info()[-1].target_fields
for source in sources:
while not isinstance(value, source.model) and source.remote_field:
source = source.remote_field.model._meta.get_field(
source.remote_field.field_name
)
source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
try:
value_list.append(getattr(value, source.attname))
except AttributeError:
@@ -68,26 +52,20 @@ class RelatedIn(In):
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if hasattr(self.lhs.output_field, "get_path_info"):
if hasattr(self.lhs.output_field, 'get_path_info'):
# Run the target field's get_prep_value. We can safely assume there is
# only one as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[
-1
]
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
return super().get_prep_lookup()
def as_sql(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
# For multicolumn lookups we need to build a multicolumn where clause.
# This clause is either a SubqueryConstraint (for values that need
# to be compiled to SQL) or an OR-combined list of
# (col1 = val1 AND col2 = val2 AND ...) clauses.
# This clause is either a SubqueryConstraint (for values that need to be compiled to
# SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
from django.db.models.sql.where import (
AND,
OR,
SubqueryConstraint,
WhereNode,
AND, OR, SubqueryConstraint, WhereNode,
)
root_constraint = WhereNode(connector=OR)
@@ -95,35 +73,24 @@ class RelatedIn(In):
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
for value in values:
value_constraint = WhereNode()
for source, target, val in zip(
self.lhs.sources, self.lhs.targets, value
):
lookup_class = target.get_lookup("exact")
lookup = lookup_class(
target.get_col(self.lhs.alias, source), val
)
for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
lookup_class = target.get_lookup('exact')
lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
value_constraint.add(lookup, AND)
root_constraint.add(value_constraint, OR)
else:
root_constraint.add(
SubqueryConstraint(
self.lhs.alias,
[target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources],
self.rhs,
),
AND,
)
self.lhs.alias, [target.column for target in self.lhs.targets],
[source.name for source in self.lhs.sources], self.rhs),
AND)
return root_constraint.as_sql(compiler, connection)
else:
if not getattr(self.rhs, "has_select_fields", True) and not getattr(
self.lhs.field.target_field, "primary_key", False
):
if (not getattr(self.rhs, 'has_select_fields', True) and
not getattr(self.lhs.field.target_field, 'primary_key', False)):
self.rhs.clear_select_clause()
if (
getattr(self.lhs.output_field, "primary_key", False)
and self.lhs.output_field.model == self.rhs.model
):
if (getattr(self.lhs.output_field, 'primary_key', False) and
self.lhs.output_field.model == self.rhs.model):
# A case like Restaurant.objects.filter(place__in=restaurant_qs),
# where place is a OneToOneField and the primary key of
# Restaurant.
@@ -136,21 +103,17 @@ class RelatedIn(In):
class RelatedLookupMixin:
def get_prep_lookup(self):
if not isinstance(self.lhs, MultiColSource) and not hasattr(
self.rhs, "resolve_expression"
):
if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'):
# If we get here, we are dealing with single-column relations.
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
# We need to run the related field's get_prep_value(). Consider case
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
# doesn't have validation for non-integers, so we must run validation
# using the target field.
if self.prepare_rhs and hasattr(self.lhs.output_field, "get_path_info"):
if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
# Get the target field. We can safely assume there is only one
# as we don't get to the direct value branch otherwise.
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[
-1
]
target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
self.rhs = target_field.get_prep_value(self.rhs)
return super().get_prep_lookup()
@@ -160,15 +123,11 @@ class RelatedLookupMixin:
assert self.rhs_is_direct_value()
self.rhs = get_normalized_value(self.rhs, self.lhs)
from django.db.models.sql.where import AND, WhereNode
root_constraint = WhereNode()
for target, source, val in zip(
self.lhs.targets, self.lhs.sources, self.rhs
):
for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
lookup_class = target.get_lookup(self.lookup_name)
root_constraint.add(
lookup_class(target.get_col(self.lhs.alias, source), val), AND
)
lookup_class(target.get_col(self.lhs.alias, source), val), AND)
return root_constraint.as_sql(compiler, connection)
return super().as_sql(compiler, connection)
@@ -36,16 +36,8 @@ class ForeignObjectRel(FieldCacheMixin):
null = True
empty_strings_allowed = False
def __init__(
self,
field,
to,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
def __init__(self, field, to, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
self.field = field
self.model = to
self.related_name = related_name
@@ -81,18 +73,14 @@ class ForeignObjectRel(FieldCacheMixin):
"""
target_fields = self.get_path_info()[-1].target_fields
if len(target_fields) > 1:
raise exceptions.FieldError(
"Can't use target_field for multicolumn relations."
)
raise exceptions.FieldError("Can't use target_field for multicolumn relations.")
return target_fields[0]
@cached_property
def related_model(self):
if not self.field.model:
raise AttributeError(
"This property can't be accessed before self.field.contribute_to_class "
"has been called."
)
"This property can't be accessed before self.field.contribute_to_class has been called.")
return self.field.model
@cached_property
@@ -122,7 +110,7 @@ class ForeignObjectRel(FieldCacheMixin):
return self.field.db_type
def __repr__(self):
return "<%s: %s.%s>" % (
return '<%s: %s.%s>' % (
type(self).__name__,
self.related_model._meta.app_label,
self.related_model._meta.model_name,
@@ -151,11 +139,8 @@ class ForeignObjectRel(FieldCacheMixin):
return hash(self.identity)
def get_choices(
self,
include_blank=True,
blank_choice=BLANK_CHOICE_DASH,
limit_choices_to=None,
ordering=(),
self, include_blank=True, blank_choice=BLANK_CHOICE_DASH,
limit_choices_to=None, ordering=(),
):
"""
Return choices with a default blank choices included, for use
@@ -168,17 +153,19 @@ class ForeignObjectRel(FieldCacheMixin):
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
if ordering:
qs = qs.order_by(*ordering)
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
return (blank_choice if include_blank else []) + [
(x.pk, str(x)) for x in qs
]
def is_hidden(self):
"""Should the related object be hidden?"""
return bool(self.related_name) and self.related_name[-1] == "+"
return bool(self.related_name) and self.related_name[-1] == '+'
def get_joining_columns(self):
return self.field.get_reverse_joining_columns()
def get_extra_restriction(self, alias, related_alias):
return self.field.get_extra_restriction(related_alias, alias)
def get_extra_restriction(self, where_class, alias, related_alias):
return self.field.get_extra_restriction(where_class, related_alias, alias)
def set_field_name(self):
"""
@@ -200,13 +187,12 @@ class ForeignObjectRel(FieldCacheMixin):
opts = model._meta if model else self.related_model._meta
model = model or self.related_model
if self.multiple:
# If this is a symmetrical m2m relation on self, there is no
# reverse accessor.
# If this is a symmetrical m2m relation on self, there is no reverse accessor.
if self.symmetrical and model == self.model:
return None
if self.related_name:
return self.related_name
return opts.model_name + ("_set" if self.multiple else "")
return opts.model_name + ('_set' if self.multiple else '')
def get_path_info(self, filtered_relation=None):
return self.field.get_reverse_path_info(filtered_relation)
@@ -234,20 +220,10 @@ class ManyToOneRel(ForeignObjectRel):
reverse relations into actual fields.
"""
def __init__(
self,
field,
to,
field_name,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
super().__init__(
field,
to,
field, to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -259,7 +235,7 @@ class ManyToOneRel(ForeignObjectRel):
def __getstate__(self):
state = self.__dict__.copy()
state.pop("related_model", None)
state.pop('related_model', None)
return state
@property
@@ -272,9 +248,7 @@ class ManyToOneRel(ForeignObjectRel):
"""
field = self.model._meta.get_field(self.field_name)
if not field.concrete:
raise exceptions.FieldDoesNotExist(
"No related field named '%s'" % self.field_name
)
raise exceptions.FieldDoesNotExist("No related field named '%s'" % self.field_name)
return field
def set_field_name(self):
@@ -289,21 +263,10 @@ class OneToOneRel(ManyToOneRel):
flags for the reverse relation.
"""
def __init__(
self,
field,
to,
field_name,
related_name=None,
related_query_name=None,
limit_choices_to=None,
parent_link=False,
on_delete=None,
):
def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
limit_choices_to=None, parent_link=False, on_delete=None):
super().__init__(
field,
to,
field_name,
field, to, field_name,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -322,21 +285,11 @@ class ManyToManyRel(ForeignObjectRel):
flags for the reverse relation.
"""
def __init__(
self,
field,
to,
related_name=None,
related_query_name=None,
limit_choices_to=None,
symmetrical=True,
through=None,
through_fields=None,
db_constraint=True,
):
def __init__(self, field, to, related_name=None, related_query_name=None,
limit_choices_to=None, symmetrical=True, through=None,
through_fields=None, db_constraint=True):
super().__init__(
field,
to,
field, to,
related_name=related_name,
related_query_name=related_query_name,
limit_choices_to=limit_choices_to,
@@ -357,7 +310,7 @@ class ManyToManyRel(ForeignObjectRel):
def identity(self):
return super().identity + (
self.through,
make_hashable(self.through_fields),
self.through_fields,
self.db_constraint,
)
@@ -371,7 +324,7 @@ class ManyToManyRel(ForeignObjectRel):
field = opts.get_field(self.through_fields[0])
else:
for field in opts.fields:
rel = getattr(field, "remote_field", None)
rel = getattr(field, 'remote_field', None)
if rel and rel.model == self.model:
break
return field.foreign_related_fields[0]