测试gitnore
This commit is contained in:
@@ -12,43 +12,38 @@ from django.utils.translation import gettext_lazy as _
|
||||
from ..utils import prefix_validation_error
|
||||
from .utils import AttributeSetter
|
||||
|
||||
__all__ = ["ArrayField"]
|
||||
__all__ = ['ArrayField']
|
||||
|
||||
|
||||
class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
default_error_messages = {
|
||||
"item_invalid": _("Item %(nth)s in the array did not validate:"),
|
||||
"nested_array_mismatch": _("Nested arrays must have the same length."),
|
||||
'item_invalid': _('Item %(nth)s in the array did not validate:'),
|
||||
'nested_array_mismatch': _('Nested arrays must have the same length.'),
|
||||
}
|
||||
_default_hint = ("list", "[]")
|
||||
_default_hint = ('list', '[]')
|
||||
|
||||
def __init__(self, base_field, size=None, **kwargs):
|
||||
self.base_field = base_field
|
||||
self.size = size
|
||||
if self.size:
|
||||
self.default_validators = [
|
||||
*self.default_validators,
|
||||
ArrayMaxLengthValidator(self.size),
|
||||
]
|
||||
self.default_validators = [*self.default_validators, ArrayMaxLengthValidator(self.size)]
|
||||
# For performance, only add a from_db_value() method if the base field
|
||||
# implements it.
|
||||
if hasattr(self.base_field, "from_db_value"):
|
||||
if hasattr(self.base_field, 'from_db_value'):
|
||||
self.from_db_value = self._from_db_value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
try:
|
||||
return self.__dict__["model"]
|
||||
return self.__dict__['model']
|
||||
except KeyError:
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute 'model'" % self.__class__.__name__
|
||||
)
|
||||
raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
|
||||
|
||||
@model.setter
|
||||
def model(self, model):
|
||||
self.__dict__["model"] = model
|
||||
self.__dict__['model'] = model
|
||||
self.base_field.model = model
|
||||
|
||||
@classmethod
|
||||
@@ -60,23 +55,21 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
if self.base_field.remote_field:
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"Base field for array cannot be a related field.",
|
||||
'Base field for array cannot be a related field.',
|
||||
obj=self,
|
||||
id="postgres.E002",
|
||||
id='postgres.E002'
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Remove the field name checks as they are not needed here.
|
||||
base_errors = self.base_field.check()
|
||||
if base_errors:
|
||||
messages = "\n ".join(
|
||||
"%s (%s)" % (error.msg, error.id) for error in base_errors
|
||||
)
|
||||
messages = '\n '.join('%s (%s)' % (error.msg, error.id) for error in base_errors)
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"Base field for array has errors:\n %s" % messages,
|
||||
'Base field for array has errors:\n %s' % messages,
|
||||
obj=self,
|
||||
id="postgres.E001",
|
||||
id='postgres.E001'
|
||||
)
|
||||
)
|
||||
return errors
|
||||
@@ -87,37 +80,32 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Array of %s" % self.base_field.description
|
||||
return 'Array of %s' % self.base_field.description
|
||||
|
||||
def db_type(self, connection):
|
||||
size = self.size or ""
|
||||
return "%s[%s]" % (self.base_field.db_type(connection), size)
|
||||
size = self.size or ''
|
||||
return '%s[%s]' % (self.base_field.db_type(connection), size)
|
||||
|
||||
def cast_db_type(self, connection):
|
||||
size = self.size or ""
|
||||
return "%s[%s]" % (self.base_field.cast_db_type(connection), size)
|
||||
size = self.size or ''
|
||||
return '%s[%s]' % (self.base_field.cast_db_type(connection), size)
|
||||
|
||||
def get_placeholder(self, value, compiler, connection):
|
||||
return "%s::{}".format(self.db_type(connection))
|
||||
return '%s::{}'.format(self.db_type(connection))
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [
|
||||
self.base_field.get_db_prep_value(i, connection, prepared=False)
|
||||
for i in value
|
||||
]
|
||||
return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
|
||||
return value
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if path == "django.contrib.postgres.fields.array.ArrayField":
|
||||
path = "django.contrib.postgres.fields.ArrayField"
|
||||
kwargs.update(
|
||||
{
|
||||
"base_field": self.base_field.clone(),
|
||||
"size": self.size,
|
||||
}
|
||||
)
|
||||
if path == 'django.contrib.postgres.fields.array.ArrayField':
|
||||
path = 'django.contrib.postgres.fields.ArrayField'
|
||||
kwargs.update({
|
||||
'base_field': self.base_field.clone(),
|
||||
'size': self.size,
|
||||
})
|
||||
return name, path, args, kwargs
|
||||
|
||||
def to_python(self, value):
|
||||
@@ -152,7 +140,7 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
if "_" not in name:
|
||||
if '_' not in name:
|
||||
try:
|
||||
index = int(name)
|
||||
except ValueError:
|
||||
@@ -161,7 +149,7 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
index += 1 # postgres uses 1-indexing
|
||||
return IndexTransformFactory(index, self.base_field)
|
||||
try:
|
||||
start, end = name.split("_")
|
||||
start, end = name.split('_')
|
||||
start = int(start) + 1
|
||||
end = int(end) # don't add one here because postgres slices are weird
|
||||
except ValueError:
|
||||
@@ -177,15 +165,15 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
except exceptions.ValidationError as error:
|
||||
raise prefix_validation_error(
|
||||
error,
|
||||
prefix=self.error_messages["item_invalid"],
|
||||
code="item_invalid",
|
||||
params={"nth": index + 1},
|
||||
prefix=self.error_messages['item_invalid'],
|
||||
code='item_invalid',
|
||||
params={'nth': index + 1},
|
||||
)
|
||||
if isinstance(self.base_field, ArrayField):
|
||||
if len({len(i) for i in value}) > 1:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["nested_array_mismatch"],
|
||||
code="nested_array_mismatch",
|
||||
self.error_messages['nested_array_mismatch'],
|
||||
code='nested_array_mismatch',
|
||||
)
|
||||
|
||||
def run_validators(self, value):
|
||||
@@ -196,20 +184,18 @@ class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
except exceptions.ValidationError as error:
|
||||
raise prefix_validation_error(
|
||||
error,
|
||||
prefix=self.error_messages["item_invalid"],
|
||||
code="item_invalid",
|
||||
params={"nth": index + 1},
|
||||
prefix=self.error_messages['item_invalid'],
|
||||
code='item_invalid',
|
||||
params={'nth': index + 1},
|
||||
)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": SimpleArrayField,
|
||||
"base_field": self.base_field.formfield(),
|
||||
"max_length": self.size,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
return super().formfield(**{
|
||||
'form_class': SimpleArrayField,
|
||||
'base_field': self.base_field.formfield(),
|
||||
'max_length': self.size,
|
||||
**kwargs,
|
||||
})
|
||||
|
||||
|
||||
class ArrayRHSMixin:
|
||||
@@ -217,21 +203,21 @@ class ArrayRHSMixin:
|
||||
if isinstance(rhs, (tuple, list)):
|
||||
expressions = []
|
||||
for value in rhs:
|
||||
if not hasattr(value, "resolve_expression"):
|
||||
if not hasattr(value, 'resolve_expression'):
|
||||
field = lhs.output_field
|
||||
value = Value(field.base_field.get_prep_value(value))
|
||||
expressions.append(value)
|
||||
rhs = Func(
|
||||
*expressions,
|
||||
function="ARRAY",
|
||||
template="%(function)s[%(expressions)s]",
|
||||
function='ARRAY',
|
||||
template='%(function)s[%(expressions)s]',
|
||||
)
|
||||
super().__init__(lhs, rhs)
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
cast_type = self.lhs.output_field.cast_db_type(connection)
|
||||
return "%s::%s" % (rhs, cast_type), rhs_params
|
||||
return '%s::%s' % (rhs, cast_type), rhs_params
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
@@ -256,29 +242,29 @@ class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayLenTransform(Transform):
|
||||
lookup_name = "len"
|
||||
lookup_name = 'len'
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
# Distinguish NULL and empty arrays
|
||||
return (
|
||||
"CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
|
||||
"coalesce(array_length(%(lhs)s, 1), 0) END"
|
||||
) % {"lhs": lhs}, params
|
||||
'CASE WHEN %(lhs)s IS NULL THEN NULL ELSE '
|
||||
'coalesce(array_length(%(lhs)s, 1), 0) END'
|
||||
) % {'lhs': lhs}, params
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayInLookup(In):
|
||||
def get_prep_lookup(self):
|
||||
values = super().get_prep_lookup()
|
||||
if hasattr(values, "resolve_expression"):
|
||||
if hasattr(values, 'resolve_expression'):
|
||||
return values
|
||||
# In.process_rhs() expects values to be hashable, so convert lists
|
||||
# to tuples.
|
||||
prepared_values = []
|
||||
for value in values:
|
||||
if hasattr(value, "resolve_expression"):
|
||||
if hasattr(value, 'resolve_expression'):
|
||||
prepared_values.append(value)
|
||||
else:
|
||||
prepared_values.append(tuple(value))
|
||||
@@ -286,6 +272,7 @@ class ArrayInLookup(In):
|
||||
|
||||
|
||||
class IndexTransform(Transform):
|
||||
|
||||
def __init__(self, index, base_field, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.index = index
|
||||
@@ -293,7 +280,7 @@ class IndexTransform(Transform):
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "%s[%%s]" % lhs, params + [self.index]
|
||||
return '%s[%%s]' % lhs, params + [self.index]
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
@@ -301,6 +288,7 @@ class IndexTransform(Transform):
|
||||
|
||||
|
||||
class IndexTransformFactory:
|
||||
|
||||
def __init__(self, index, base_field):
|
||||
self.index = index
|
||||
self.base_field = base_field
|
||||
@@ -310,6 +298,7 @@ class IndexTransformFactory:
|
||||
|
||||
|
||||
class SliceTransform(Transform):
|
||||
|
||||
def __init__(self, start, end, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.start = start
|
||||
@@ -317,10 +306,11 @@ class SliceTransform(Transform):
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "%s[%%s:%%s]" % lhs, params + [self.start, self.end]
|
||||
return '%s[%%s:%%s]' % lhs, params + [self.start, self.end]
|
||||
|
||||
|
||||
class SliceTransformFactory:
|
||||
|
||||
def __init__(self, start, end):
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from django.db.models import CharField, EmailField, TextField
|
||||
|
||||
__all__ = ["CICharField", "CIEmailField", "CIText", "CITextField"]
|
||||
__all__ = ['CICharField', 'CIEmailField', 'CIText', 'CITextField']
|
||||
|
||||
|
||||
class CIText:
|
||||
|
||||
def get_internal_type(self):
|
||||
return "CI" + super().get_internal_type()
|
||||
return 'CI' + super().get_internal_type()
|
||||
|
||||
def db_type(self, connection):
|
||||
return "citext"
|
||||
return 'citext'
|
||||
|
||||
|
||||
class CICharField(CIText, CharField):
|
||||
|
||||
@@ -7,19 +7,19 @@ from django.db.models import Field, TextField, Transform
|
||||
from django.db.models.fields.mixins import CheckFieldDefaultMixin
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
__all__ = ["HStoreField"]
|
||||
__all__ = ['HStoreField']
|
||||
|
||||
|
||||
class HStoreField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("Map of strings to strings/nulls")
|
||||
description = _('Map of strings to strings/nulls')
|
||||
default_error_messages = {
|
||||
"not_a_string": _("The value of “%(key)s” is not a string or null."),
|
||||
'not_a_string': _('The value of “%(key)s” is not a string or null.'),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
_default_hint = ('dict', '{}')
|
||||
|
||||
def db_type(self, connection):
|
||||
return "hstore"
|
||||
return 'hstore'
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
@@ -32,9 +32,9 @@ class HStoreField(CheckFieldDefaultMixin, Field):
|
||||
for key, val in value.items():
|
||||
if not isinstance(val, str) and val is not None:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["not_a_string"],
|
||||
code="not_a_string",
|
||||
params={"key": key},
|
||||
self.error_messages['not_a_string'],
|
||||
code='not_a_string',
|
||||
params={'key': key},
|
||||
)
|
||||
|
||||
def to_python(self, value):
|
||||
@@ -46,12 +46,10 @@ class HStoreField(CheckFieldDefaultMixin, Field):
|
||||
return json.dumps(self.value_from_object(obj))
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.HStoreField,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
return super().formfield(**{
|
||||
'form_class': forms.HStoreField,
|
||||
**kwargs,
|
||||
})
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
@@ -87,10 +85,11 @@ class KeyTransform(Transform):
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "(%s -> %%s)" % lhs, tuple(params) + (self.key_name,)
|
||||
return '(%s -> %%s)' % lhs, tuple(params) + (self.key_name,)
|
||||
|
||||
|
||||
class KeyTransformFactory:
|
||||
|
||||
def __init__(self, key_name):
|
||||
self.key_name = key_name
|
||||
|
||||
@@ -100,13 +99,13 @@ class KeyTransformFactory:
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class KeysTransform(Transform):
|
||||
lookup_name = "keys"
|
||||
function = "akeys"
|
||||
lookup_name = 'keys'
|
||||
function = 'akeys'
|
||||
output_field = ArrayField(TextField())
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class ValuesTransform(Transform):
|
||||
lookup_name = "values"
|
||||
function = "avals"
|
||||
lookup_name = 'values'
|
||||
function = 'avals'
|
||||
output_field = ArrayField(TextField())
|
||||
|
||||
@@ -1,14 +1,43 @@
|
||||
from django.db.models import JSONField as BuiltinJSONField
|
||||
import warnings
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
from django.db.models import JSONField as BuiltinJSONField
|
||||
from django.db.models.fields.json import (
|
||||
KeyTextTransform as BuiltinKeyTextTransform,
|
||||
KeyTransform as BuiltinKeyTransform,
|
||||
)
|
||||
from django.utils.deprecation import RemovedInDjango40Warning
|
||||
|
||||
__all__ = ['JSONField']
|
||||
|
||||
|
||||
class JSONField(BuiltinJSONField):
|
||||
system_check_removed_details = {
|
||||
"msg": (
|
||||
"django.contrib.postgres.fields.JSONField is removed except for "
|
||||
"support in historical migrations."
|
||||
system_check_deprecated_details = {
|
||||
'msg': (
|
||||
'django.contrib.postgres.fields.JSONField is deprecated. Support '
|
||||
'for it (except in historical migrations) will be removed in '
|
||||
'Django 4.0.'
|
||||
),
|
||||
"hint": "Use django.db.models.JSONField instead.",
|
||||
"id": "fields.E904",
|
||||
'hint': 'Use django.db.models.JSONField instead.',
|
||||
'id': 'fields.W904',
|
||||
}
|
||||
|
||||
|
||||
class KeyTransform(BuiltinKeyTransform):
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
'django.contrib.postgres.fields.jsonb.KeyTransform is deprecated '
|
||||
'in favor of django.db.models.fields.json.KeyTransform.',
|
||||
RemovedInDjango40Warning, stacklevel=2,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class KeyTextTransform(BuiltinKeyTextTransform):
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
'django.contrib.postgres.fields.jsonb.KeyTextTransform is '
|
||||
'deprecated in favor of '
|
||||
'django.db.models.fields.json.KeyTextTransform.',
|
||||
RemovedInDjango40Warning, stacklevel=2,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -10,23 +10,17 @@ from django.db.models.lookups import PostgresOperatorLookup
|
||||
from .utils import AttributeSetter
|
||||
|
||||
__all__ = [
|
||||
"RangeField",
|
||||
"IntegerRangeField",
|
||||
"BigIntegerRangeField",
|
||||
"DecimalRangeField",
|
||||
"DateTimeRangeField",
|
||||
"DateRangeField",
|
||||
"RangeBoundary",
|
||||
"RangeOperators",
|
||||
'RangeField', 'IntegerRangeField', 'BigIntegerRangeField',
|
||||
'DecimalRangeField', 'DateTimeRangeField', 'DateRangeField',
|
||||
'RangeBoundary', 'RangeOperators',
|
||||
]
|
||||
|
||||
|
||||
class RangeBoundary(models.Expression):
|
||||
"""A class that represents range boundaries."""
|
||||
|
||||
def __init__(self, inclusive_lower=True, inclusive_upper=False):
|
||||
self.lower = "[" if inclusive_lower else "("
|
||||
self.upper = "]" if inclusive_upper else ")"
|
||||
self.lower = '[' if inclusive_lower else '('
|
||||
self.upper = ']' if inclusive_upper else ')'
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return "'%s%s'" % (self.lower, self.upper), []
|
||||
@@ -34,40 +28,37 @@ class RangeBoundary(models.Expression):
|
||||
|
||||
class RangeOperators:
|
||||
# https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "<>"
|
||||
CONTAINS = "@>"
|
||||
CONTAINED_BY = "<@"
|
||||
OVERLAPS = "&&"
|
||||
FULLY_LT = "<<"
|
||||
FULLY_GT = ">>"
|
||||
NOT_LT = "&>"
|
||||
NOT_GT = "&<"
|
||||
ADJACENT_TO = "-|-"
|
||||
EQUAL = '='
|
||||
NOT_EQUAL = '<>'
|
||||
CONTAINS = '@>'
|
||||
CONTAINED_BY = '<@'
|
||||
OVERLAPS = '&&'
|
||||
FULLY_LT = '<<'
|
||||
FULLY_GT = '>>'
|
||||
NOT_LT = '&>'
|
||||
NOT_GT = '&<'
|
||||
ADJACENT_TO = '-|-'
|
||||
|
||||
|
||||
class RangeField(models.Field):
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Initializing base_field here ensures that its model matches the model
|
||||
# for self.
|
||||
if hasattr(self, "base_field"):
|
||||
# Initializing base_field here ensures that its model matches the model for self.
|
||||
if hasattr(self, 'base_field'):
|
||||
self.base_field = self.base_field()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
try:
|
||||
return self.__dict__["model"]
|
||||
return self.__dict__['model']
|
||||
except KeyError:
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute 'model'" % self.__class__.__name__
|
||||
)
|
||||
raise AttributeError("'%s' object has no attribute 'model'" % self.__class__.__name__)
|
||||
|
||||
@model.setter
|
||||
def model(self, model):
|
||||
self.__dict__["model"] = model
|
||||
self.__dict__['model'] = model
|
||||
self.base_field.model = model
|
||||
|
||||
@classmethod
|
||||
@@ -87,7 +78,7 @@ class RangeField(models.Field):
|
||||
if isinstance(value, str):
|
||||
# Assume we're deserializing
|
||||
vals = json.loads(value)
|
||||
for end in ("lower", "upper"):
|
||||
for end in ('lower', 'upper'):
|
||||
if end in vals:
|
||||
vals[end] = self.base_field.to_python(vals[end])
|
||||
value = self.range_type(**vals)
|
||||
@@ -107,7 +98,7 @@ class RangeField(models.Field):
|
||||
return json.dumps({"empty": True})
|
||||
base_field = self.base_field
|
||||
result = {"bounds": value._bounds}
|
||||
for end in ("lower", "upper"):
|
||||
for end in ('lower', 'upper'):
|
||||
val = getattr(value, end)
|
||||
if val is None:
|
||||
result[end] = None
|
||||
@@ -117,7 +108,7 @@ class RangeField(models.Field):
|
||||
return json.dumps(result)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
kwargs.setdefault("form_class", self.form_field)
|
||||
kwargs.setdefault('form_class', self.form_field)
|
||||
return super().formfield(**kwargs)
|
||||
|
||||
|
||||
@@ -127,7 +118,7 @@ class IntegerRangeField(RangeField):
|
||||
form_field = forms.IntegerRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "int4range"
|
||||
return 'int4range'
|
||||
|
||||
|
||||
class BigIntegerRangeField(RangeField):
|
||||
@@ -136,7 +127,7 @@ class BigIntegerRangeField(RangeField):
|
||||
form_field = forms.IntegerRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "int8range"
|
||||
return 'int8range'
|
||||
|
||||
|
||||
class DecimalRangeField(RangeField):
|
||||
@@ -145,7 +136,7 @@ class DecimalRangeField(RangeField):
|
||||
form_field = forms.DecimalRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "numrange"
|
||||
return 'numrange'
|
||||
|
||||
|
||||
class DateTimeRangeField(RangeField):
|
||||
@@ -154,7 +145,7 @@ class DateTimeRangeField(RangeField):
|
||||
form_field = forms.DateTimeRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "tstzrange"
|
||||
return 'tstzrange'
|
||||
|
||||
|
||||
class DateRangeField(RangeField):
|
||||
@@ -163,7 +154,7 @@ class DateRangeField(RangeField):
|
||||
form_field = forms.DateRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "daterange"
|
||||
return 'daterange'
|
||||
|
||||
|
||||
RangeField.register_lookup(lookups.DataContains)
|
||||
@@ -176,8 +167,7 @@ class DateTimeRangeContains(PostgresOperatorLookup):
|
||||
Lookup for Date/DateTimeRange containment to cast the rhs to the correct
|
||||
type.
|
||||
"""
|
||||
|
||||
lookup_name = "contains"
|
||||
lookup_name = 'contains'
|
||||
postgres_operator = RangeOperators.CONTAINS
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
@@ -190,19 +180,16 @@ class DateTimeRangeContains(PostgresOperatorLookup):
|
||||
def as_postgresql(self, compiler, connection):
|
||||
sql, params = super().as_postgresql(compiler, connection)
|
||||
# Cast the rhs if needed.
|
||||
cast_sql = ""
|
||||
cast_sql = ''
|
||||
if (
|
||||
isinstance(self.rhs, models.Expression)
|
||||
and self.rhs._output_field_or_none
|
||||
and
|
||||
isinstance(self.rhs, models.Expression) and
|
||||
self.rhs._output_field_or_none and
|
||||
# Skip cast if rhs has a matching range type.
|
||||
not isinstance(
|
||||
self.rhs._output_field_or_none, self.lhs.output_field.__class__
|
||||
)
|
||||
not isinstance(self.rhs._output_field_or_none, self.lhs.output_field.__class__)
|
||||
):
|
||||
cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
|
||||
cast_sql = "::{}".format(connection.data_types.get(cast_internal_type))
|
||||
return "%s%s" % (sql, cast_sql), params
|
||||
cast_sql = '::{}'.format(connection.data_types.get(cast_internal_type))
|
||||
return '%s%s' % (sql, cast_sql), params
|
||||
|
||||
|
||||
DateRangeField.register_lookup(DateTimeRangeContains)
|
||||
@@ -210,31 +197,31 @@ DateTimeRangeField.register_lookup(DateTimeRangeContains)
|
||||
|
||||
|
||||
class RangeContainedBy(PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
lookup_name = 'contained_by'
|
||||
type_mapping = {
|
||||
"smallint": "int4range",
|
||||
"integer": "int4range",
|
||||
"bigint": "int8range",
|
||||
"double precision": "numrange",
|
||||
"numeric": "numrange",
|
||||
"date": "daterange",
|
||||
"timestamp with time zone": "tstzrange",
|
||||
'smallint': 'int4range',
|
||||
'integer': 'int4range',
|
||||
'bigint': 'int8range',
|
||||
'double precision': 'numrange',
|
||||
'numeric': 'numrange',
|
||||
'date': 'daterange',
|
||||
'timestamp with time zone': 'tstzrange',
|
||||
}
|
||||
postgres_operator = RangeOperators.CONTAINED_BY
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Ignore precision for DecimalFields.
|
||||
db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0]
|
||||
db_type = self.lhs.output_field.cast_db_type(connection).split('(')[0]
|
||||
cast_type = self.type_mapping[db_type]
|
||||
return "%s::%s" % (rhs, cast_type), rhs_params
|
||||
return '%s::%s' % (rhs, cast_type), rhs_params
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if isinstance(self.lhs.output_field, models.FloatField):
|
||||
lhs = "%s::numeric" % lhs
|
||||
lhs = '%s::numeric' % lhs
|
||||
elif isinstance(self.lhs.output_field, models.SmallIntegerField):
|
||||
lhs = "%s::integer" % lhs
|
||||
lhs = '%s::integer' % lhs
|
||||
return lhs, lhs_params
|
||||
|
||||
def get_prep_lookup(self):
|
||||
@@ -250,38 +237,38 @@ models.DecimalField.register_lookup(RangeContainedBy)
|
||||
|
||||
@RangeField.register_lookup
|
||||
class FullyLessThan(PostgresOperatorLookup):
|
||||
lookup_name = "fully_lt"
|
||||
lookup_name = 'fully_lt'
|
||||
postgres_operator = RangeOperators.FULLY_LT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class FullGreaterThan(PostgresOperatorLookup):
|
||||
lookup_name = "fully_gt"
|
||||
lookup_name = 'fully_gt'
|
||||
postgres_operator = RangeOperators.FULLY_GT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class NotLessThan(PostgresOperatorLookup):
|
||||
lookup_name = "not_lt"
|
||||
lookup_name = 'not_lt'
|
||||
postgres_operator = RangeOperators.NOT_LT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class NotGreaterThan(PostgresOperatorLookup):
|
||||
lookup_name = "not_gt"
|
||||
lookup_name = 'not_gt'
|
||||
postgres_operator = RangeOperators.NOT_GT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class AdjacentToLookup(PostgresOperatorLookup):
|
||||
lookup_name = "adjacent_to"
|
||||
lookup_name = 'adjacent_to'
|
||||
postgres_operator = RangeOperators.ADJACENT_TO
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class RangeStartsWith(models.Transform):
|
||||
lookup_name = "startswith"
|
||||
function = "lower"
|
||||
lookup_name = 'startswith'
|
||||
function = 'lower'
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
@@ -290,8 +277,8 @@ class RangeStartsWith(models.Transform):
|
||||
|
||||
@RangeField.register_lookup
|
||||
class RangeEndsWith(models.Transform):
|
||||
lookup_name = "endswith"
|
||||
function = "upper"
|
||||
lookup_name = 'endswith'
|
||||
function = 'upper'
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
@@ -300,34 +287,34 @@ class RangeEndsWith(models.Transform):
|
||||
|
||||
@RangeField.register_lookup
|
||||
class IsEmpty(models.Transform):
|
||||
lookup_name = "isempty"
|
||||
function = "isempty"
|
||||
lookup_name = 'isempty'
|
||||
function = 'isempty'
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class LowerInclusive(models.Transform):
|
||||
lookup_name = "lower_inc"
|
||||
function = "LOWER_INC"
|
||||
lookup_name = 'lower_inc'
|
||||
function = 'LOWER_INC'
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class LowerInfinite(models.Transform):
|
||||
lookup_name = "lower_inf"
|
||||
function = "LOWER_INF"
|
||||
lookup_name = 'lower_inf'
|
||||
function = 'LOWER_INF'
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class UpperInclusive(models.Transform):
|
||||
lookup_name = "upper_inc"
|
||||
function = "UPPER_INC"
|
||||
lookup_name = 'upper_inc'
|
||||
function = 'UPPER_INC'
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class UpperInfinite(models.Transform):
|
||||
lookup_name = "upper_inf"
|
||||
function = "UPPER_INF"
|
||||
lookup_name = 'upper_inf'
|
||||
function = 'UPPER_INF'
|
||||
output_field = models.BooleanField()
|
||||
|
||||
Reference in New Issue
Block a user