测试gitnore

This commit is contained in:
ladeng07
2022-05-06 15:45:57 +08:00
parent 12f390949b
commit 51552904f9
2347 changed files with 120102 additions and 53549 deletions
@@ -11,10 +11,11 @@ from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError
from django.db import connections
from django.db import DatabaseError as WrappedDatabaseError, connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.db.backends.utils import (
CursorDebugWrapper as BaseCursorDebugWrapper,
)
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
@@ -29,17 +30,14 @@ except ImportError as e:
def psycopg2_version():
version = psycopg2.__version__.split(" ", 1)[0]
version = psycopg2.__version__.split(' ', 1)[0]
return get_version_tuple(version)
PSYCOPG2_VERSION = psycopg2_version()
if PSYCOPG2_VERSION < (2, 8, 4):
raise ImproperlyConfigured(
"psycopg2 version 2.8.4 or newer is required; you have %s"
% psycopg2.__version__
)
if PSYCOPG2_VERSION < (2, 5, 4):
raise ImproperlyConfigured("psycopg2_version 2.5.4 or newer is required; you have %s" % psycopg2.__version__)
# Some of these import psycopg2, so import them after checking if it's installed.
@@ -58,68 +56,69 @@ psycopg2.extras.register_uuid()
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
"INETARRAY",
'INETARRAY',
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "postgresql"
display_name = "PostgreSQL"
vendor = 'postgresql'
display_name = 'PostgreSQL'
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
"AutoField": "serial",
"BigAutoField": "bigserial",
"BinaryField": "bytea",
"BooleanField": "boolean",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "timestamp with time zone",
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
"DurationField": "interval",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "double precision",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "inet",
"GenericIPAddressField": "inet",
"JSONField": "jsonb",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint",
"PositiveIntegerField": "integer",
"PositiveSmallIntegerField": "smallint",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallserial",
"SmallIntegerField": "smallint",
"TextField": "text",
"TimeField": "time",
"UUIDField": "uuid",
'AutoField': 'serial',
'BigAutoField': 'bigserial',
'BinaryField': 'bytea',
'BooleanField': 'boolean',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'interval',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'inet',
'GenericIPAddressField': 'inet',
'JSONField': 'jsonb',
'NullBooleanField': 'boolean',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint',
'PositiveIntegerField': 'integer',
'PositiveSmallIntegerField': 'smallint',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallserial',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'UUIDField': 'uuid',
}
data_type_check_constraints = {
"PositiveBigIntegerField": '"%(column)s" >= 0',
"PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0',
'PositiveBigIntegerField': '"%(column)s" >= 0',
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}
operators = {
"exact": "= %s",
"iexact": "= UPPER(%s)",
"contains": "LIKE %s",
"icontains": "LIKE UPPER(%s)",
"regex": "~ %s",
"iregex": "~* %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE %s",
"endswith": "LIKE %s",
"istartswith": "LIKE UPPER(%s)",
"iendswith": "LIKE UPPER(%s)",
'exact': '= %s',
'iexact': '= UPPER(%s)',
'contains': 'LIKE %s',
'icontains': 'LIKE UPPER(%s)',
'regex': '~ %s',
'iregex': '~* %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE %s',
'endswith': 'LIKE %s',
'istartswith': 'LIKE UPPER(%s)',
'iendswith': 'LIKE UPPER(%s)',
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -130,16 +129,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = (
r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
)
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
pattern_ops = {
"contains": "LIKE '%%' || {} || '%%'",
"icontains": "LIKE '%%' || UPPER({}) || '%%'",
"startswith": "LIKE {} || '%%'",
"istartswith": "LIKE UPPER({}) || '%%'",
"endswith": "LIKE '%%' || {}",
"iendswith": "LIKE '%%' || UPPER({})",
'contains': "LIKE '%%' || {} || '%%'",
'icontains': "LIKE '%%' || UPPER({}) || '%%'",
'startswith': "LIKE {} || '%%'",
'istartswith': "LIKE UPPER({}) || '%%'",
'endswith': "LIKE '%%' || {}",
'iendswith': "LIKE '%%' || UPPER({})",
}
Database = Database
@@ -156,46 +153,33 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
"service"
):
if settings_dict['NAME'] == '':
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME or OPTIONS['service'] value."
)
if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
"Please supply the NAME value.")
if len(settings_dict['NAME'] or '') > self.ops.max_name_length():
raise ImproperlyConfigured(
"The database name '%s' (%d characters) is longer than "
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
"in settings.DATABASES."
% (
settings_dict["NAME"],
len(settings_dict["NAME"]),
"in settings.DATABASES." % (
settings_dict['NAME'],
len(settings_dict['NAME']),
self.ops.max_name_length(),
)
)
conn_params = {}
if settings_dict["NAME"]:
conn_params = {
"database": settings_dict["NAME"],
**settings_dict["OPTIONS"],
}
elif settings_dict["NAME"] is None:
# Connect to the default 'postgres' db.
settings_dict.get("OPTIONS", {}).pop("service", None)
conn_params = {"database": "postgres", **settings_dict["OPTIONS"]}
else:
conn_params = {**settings_dict["OPTIONS"]}
conn_params.pop("isolation_level", None)
if settings_dict["USER"]:
conn_params["user"] = settings_dict["USER"]
if settings_dict["PASSWORD"]:
conn_params["password"] = settings_dict["PASSWORD"]
if settings_dict["HOST"]:
conn_params["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
conn_params["port"] = settings_dict["PORT"]
conn_params = {
'database': settings_dict['NAME'] or 'postgres',
**settings_dict['OPTIONS'],
}
conn_params.pop('isolation_level', None)
if settings_dict['USER']:
conn_params['user'] = settings_dict['USER']
if settings_dict['PASSWORD']:
conn_params['password'] = settings_dict['PASSWORD']
if settings_dict['HOST']:
conn_params['host'] = settings_dict['HOST']
if settings_dict['PORT']:
conn_params['port'] = settings_dict['PORT']
return conn_params
@async_unsafe
@@ -207,9 +191,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# default when no value is explicitly specified in options.
# - before calling _set_autocommit() because if autocommit is on, that
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
options = self.settings_dict["OPTIONS"]
options = self.settings_dict['OPTIONS']
try:
self.isolation_level = options["isolation_level"]
self.isolation_level = options['isolation_level']
except KeyError:
self.isolation_level = connection.isolation_level
else:
@@ -219,15 +203,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Register dummy loads() to avoid a round trip from psycopg2's decode
# to json.dumps() to json.loads(), when using a custom decoder in
# JSONField.
psycopg2.extras.register_default_jsonb(
conn_or_curs=connection, loads=lambda x: x
)
psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x)
return connection
def ensure_timezone(self):
if self.connection is None:
return False
conn_timezone_name = self.connection.get_parameter_status("TimeZone")
conn_timezone_name = self.connection.get_parameter_status('TimeZone')
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with self.connection.cursor() as cursor:
@@ -236,7 +218,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False
def init_connection_state(self):
self.connection.set_client_encoding("UTF8")
self.connection.set_client_encoding('UTF8')
timezone_changed = self.ensure_timezone()
if timezone_changed:
@@ -249,9 +231,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if name:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
cursor = self.connection.cursor(
name, scrollable=False, withhold=self.connection.autocommit
)
cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit)
else:
cursor = self.connection.cursor()
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
@@ -269,18 +249,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# For now, it's here so that every use of "threading" is
# also async-compatible.
try:
current_task = asyncio.current_task()
if hasattr(asyncio, 'current_task'):
# Python 3.7 and up
current_task = asyncio.current_task()
else:
# Python 3.6
current_task = asyncio.Task.current_task()
except RuntimeError:
current_task = None
# Current task can be none even if the current_task call didn't error
if current_task:
task_ident = str(id(current_task))
else:
task_ident = "sync"
task_ident = 'sync'
# Use that and the thread ident to get a unique name
return self._cursor(
name="_django_curs_%d_%s_%d"
% (
name='_django_curs_%d_%s_%d' % (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
@@ -298,14 +282,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
afterward.
"""
with self.cursor() as cursor:
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.execute('SELECT 1')
except Database.Error:
return False
else:
@@ -313,31 +297,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@contextmanager
def _nodb_cursor(self):
cursor = None
try:
with super()._nodb_cursor() as cursor:
yield cursor
except (Database.DatabaseError, WrappedDatabaseError):
if cursor is not None:
raise
warnings.warn(
"Normally Django will use a connection to the 'postgres' database "
"to avoid running initialization queries against the production "
"database when it's not needed (for example, when running tests). "
"Django was unable to create a connection to the 'postgres' database "
"and will use the first PostgreSQL database instead.",
RuntimeWarning,
RuntimeWarning
)
for connection in connections.all():
if (
connection.vendor == "postgresql"
and connection.settings_dict["NAME"] != "postgres"
):
if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres':
conn = self.__class__(
{
**self.settings_dict,
"NAME": connection.settings_dict["NAME"],
},
{**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
alias=self.alias,
)
try:
@@ -364,5 +339,5 @@ class CursorDebugWrapper(BaseCursorDebugWrapper):
return self.cursor.copy_expert(sql, file, *args)
def copy_to(self, file, table, *args, **kwargs):
with self.debug_sql(sql="COPY %s TO STDOUT" % table):
with self.debug_sql(sql='COPY %s TO STDOUT' % table):
return self.cursor.copy_to(file, table, *args, **kwargs)
@@ -4,53 +4,43 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "psql"
executable_name = 'psql'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
options = settings_dict.get("OPTIONS", {})
options = settings_dict.get('OPTIONS', {})
host = settings_dict.get("HOST")
port = settings_dict.get("PORT")
dbname = settings_dict.get("NAME")
user = settings_dict.get("USER")
passwd = settings_dict.get("PASSWORD")
passfile = options.get("passfile")
service = options.get("service")
sslmode = options.get("sslmode")
sslrootcert = options.get("sslrootcert")
sslcert = options.get("sslcert")
sslkey = options.get("sslkey")
host = settings_dict.get('HOST')
port = settings_dict.get('PORT')
dbname = settings_dict.get('NAME') or 'postgres'
user = settings_dict.get('USER')
passwd = settings_dict.get('PASSWORD')
sslmode = options.get('sslmode')
sslrootcert = options.get('sslrootcert')
sslcert = options.get('sslcert')
sslkey = options.get('sslkey')
if not dbname and not service:
# Connect to the default 'postgres' db.
dbname = "postgres"
if user:
args += ["-U", user]
args += ['-U', user]
if host:
args += ["-h", host]
args += ['-h', host]
if port:
args += ["-p", str(port)]
if dbname:
args += [dbname]
args += ['-p', str(port)]
args += [dbname]
args.extend(parameters)
env = {}
if passwd:
env["PGPASSWORD"] = str(passwd)
if service:
env["PGSERVICE"] = str(service)
env['PGPASSWORD'] = str(passwd)
if sslmode:
env["PGSSLMODE"] = str(sslmode)
env['PGSSLMODE'] = str(sslmode)
if sslrootcert:
env["PGSSLROOTCERT"] = str(sslrootcert)
env['PGSSLROOTCERT'] = str(sslrootcert)
if sslcert:
env["PGSSLCERT"] = str(sslcert)
env['PGSSLCERT'] = str(sslcert)
if sslkey:
env["PGSSLKEY"] = str(sslkey)
if passfile:
env["PGPASSFILE"] = str(passfile)
env['PGSSLKEY'] = str(sslkey)
return args, (env or None)
def runshell(self, parameters):
@@ -2,12 +2,12 @@ import sys
from psycopg2 import errorcodes
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.utils import strip_quotes
class DatabaseCreation(BaseDatabaseCreation):
def _quote_name(self, name):
return self.connection.ops.quote_name(name)
@@ -20,35 +20,30 @@ class DatabaseCreation(BaseDatabaseCreation):
return suffix and "WITH" + suffix
def sql_table_creation_suffix(self):
test_settings = self.connection.settings_dict["TEST"]
if test_settings.get("COLLATION") is not None:
raise ImproperlyConfigured(
"PostgreSQL does not support collation setting at database "
"creation time."
)
test_settings = self.connection.settings_dict['TEST']
assert test_settings['COLLATION'] is None, (
"PostgreSQL does not support collation setting at database creation time."
)
return self._get_database_create_suffix(
encoding=test_settings["CHARSET"],
template=test_settings.get("TEMPLATE"),
encoding=test_settings['CHARSET'],
template=test_settings.get('TEMPLATE'),
)
def _database_exists(self, cursor, database_name):
cursor.execute(
"SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
[strip_quotes(database_name)],
)
cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)])
return cursor.fetchone() is not None
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
if keepdb and self._database_exists(cursor, parameters["dbname"]):
if keepdb and self._database_exists(cursor, parameters['dbname']):
# If the database should be kept and it already exists, don't
# try to create a new one.
return
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if getattr(e.__cause__, "pgcode", "") != errorcodes.DUPLICATE_DATABASE:
if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE:
# All errors except "database already exists" cancel tests.
self.log("Got an error creating the test database: %s" % e)
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
elif not keepdb:
# If the database should be kept, ignore "database already
@@ -60,11 +55,11 @@ class DatabaseCreation(BaseDatabaseCreation):
# to the template database.
self.connection.close()
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
test_db_params = {
"dbname": self._quote_name(target_database_name),
"suffix": self._get_database_create_suffix(template=source_database_name),
'dbname': self._quote_name(target_database_name),
'suffix': self._get_database_create_suffix(template=source_database_name),
}
with self._nodb_cursor() as cursor:
try:
@@ -72,16 +67,11 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception:
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log("Got an error cloning the test database: %s" % e)
self.log('Got an error cloning the test database: %s' % e)
sys.exit(2)
@@ -53,32 +53,41 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_over_clause = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}
validates_explain_options = False # A query will error on invalid options.
supports_deferrable_unique_constraints = True
has_json_operators = True
json_key_contains_list_matching_requires_list = True
test_collations = {
"non_default": "sv-x-icu",
"swedish_ci": "sv-x-icu",
}
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
django_test_skips = {
"opclasses are PostgreSQL only.": {
"indexes.tests.SchemaIndexesNotPostgreSQLTests."
"test_create_index_ignores_opclasses",
'opclasses are PostgreSQL only.': {
'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses',
},
}
@cached_property
def test_collations(self):
# PostgreSQL < 10 doesn't support ICU collations.
if self.is_postgresql_10:
return {
'non_default': 'sv-x-icu',
'swedish_ci': 'sv-x-icu',
}
return {}
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"PositiveBigIntegerField": "BigIntegerField",
"PositiveIntegerField": "IntegerField",
"PositiveSmallIntegerField": "SmallIntegerField",
'PositiveBigIntegerField': 'BigIntegerField',
'PositiveIntegerField': 'IntegerField',
'PositiveSmallIntegerField': 'SmallIntegerField',
}
@cached_property
def is_postgresql_10(self):
return self.connection.pg_version >= 100000
@cached_property
def is_postgresql_11(self):
return self.connection.pg_version >= 110000
@@ -91,9 +100,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
def is_postgresql_13(self):
return self.connection.pg_version >= 130000
has_websearch_to_tsquery = property(operator.attrgetter("is_postgresql_11"))
supports_covering_indexes = property(operator.attrgetter("is_postgresql_11"))
supports_covering_gist_indexes = property(operator.attrgetter("is_postgresql_12"))
supports_non_deterministic_collations = property(
operator.attrgetter("is_postgresql_12")
)
has_brin_autosummarize = property(operator.attrgetter('is_postgresql_10'))
has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11'))
supports_table_partitions = property(operator.attrgetter('is_postgresql_10'))
supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
supports_alternate_collation_providers = property(operator.attrgetter('is_postgresql_10'))
@@ -1,7 +1,5 @@
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection,
FieldInfo,
TableInfo,
BaseDatabaseIntrospection, FieldInfo, TableInfo,
)
from django.db.models import Index
@@ -9,66 +7,55 @@ from django.db.models import Index
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
16: "BooleanField",
17: "BinaryField",
20: "BigIntegerField",
21: "SmallIntegerField",
23: "IntegerField",
25: "TextField",
700: "FloatField",
701: "FloatField",
869: "GenericIPAddressField",
1042: "CharField", # blank-padded
1043: "CharField",
1082: "DateField",
1083: "TimeField",
1114: "DateTimeField",
1184: "DateTimeField",
1186: "DurationField",
1266: "TimeField",
1700: "DecimalField",
2950: "UUIDField",
3802: "JSONField",
16: 'BooleanField',
17: 'BinaryField',
20: 'BigIntegerField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
700: 'FloatField',
701: 'FloatField',
869: 'GenericIPAddressField',
1042: 'CharField', # blank-padded
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1186: 'DurationField',
1266: 'TimeField',
1700: 'DecimalField',
2950: 'UUIDField',
3802: 'JSONField',
}
# A hook for subclasses.
index_default_access_method = "btree"
index_default_access_method = 'btree'
ignored_tables = []
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.default and "nextval" in description.default:
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
if description.default and 'nextval' in description.default:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute(
"""
SELECT
c.relname,
CASE
WHEN c.relispartition THEN 'p'
WHEN c.relkind IN ('m', 'v') THEN 'v'
ELSE 't'
END
cursor.execute("""
SELECT c.relname,
CASE WHEN {} THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
"""
)
return [
TableInfo(*row)
for row in cursor.fetchall()
if row[0] not in self.ignored_tables
]
""".format('c.relispartition' if self.connection.features.supports_table_partitions else 'FALSE'))
return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables]
def get_table_description(self, cursor, table_name):
"""
@@ -78,8 +65,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Query the pg_catalog tables as cursor.description does not reliably
# return the nullable property and information_schema.columns does not
# contain details of materialized views.
cursor.execute(
"""
cursor.execute("""
SELECT
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
@@ -95,13 +81,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND c.relname = %s
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
""",
[table_name],
)
""", [table_name])
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
cursor.execute(
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
)
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return [
FieldInfo(
line.name,
@@ -116,30 +98,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
]
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute(
"""
cursor.execute("""
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN
pg_depend d ON d.refobjid = s.oid
AND d.refclassid = 'pg_class'::regclass
JOIN
pg_attrdef ad ON ad.oid = d.objid
AND d.classid = 'pg_attrdef'::regclass
JOIN
pg_attribute col ON col.attrelid = ad.adrelid
AND col.attnum = ad.adnum
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid = 'pg_class'::regclass
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
""",
[table_name],
)
""", [table_name])
return [
{"name": row[0], "table": table_name, "column": row[1]}
{'name': row[0], 'table': table_name, 'column': row[1]}
for row in cursor.fetchall()
]
@@ -148,29 +121,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
return {
row[0]: (row[2], row[1]) for row in self.get_key_columns(cursor, table_name)
}
return {row[0]: (row[2], row[1]) for row in self.get_key_columns(cursor, table_name)}
def get_key_columns(self, cursor, table_name):
cursor.execute(
"""
cursor.execute("""
SELECT a1.attname, c2.relname, a2.attname
FROM pg_constraint con
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
LEFT JOIN
pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN
pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
WHERE
c1.relname = %s AND
con.contype = 'f' AND
c1.relnamespace = c2.relnamespace AND
pg_catalog.pg_table_is_visible(c1.oid)
""",
[table_name],
)
""", [table_name])
return cursor.fetchall()
def get_constraints(self, cursor, table_name):
@@ -183,8 +149,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Loop over the key table, collecting things as constraints. The column
# array must return column names in the same order in which they were
# created.
cursor.execute(
"""
cursor.execute("""
SELECT
c.conname,
array(
@@ -203,9 +168,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
FROM pg_constraint AS c
JOIN pg_class AS cl ON c.conrelid = cl.oid
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
""",
[table_name],
)
""", [table_name])
for constraint, columns, kind, used_cols, options in cursor.fetchall():
constraints[constraint] = {
"columns": columns,
@@ -218,17 +181,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"options": options,
}
# Now get indexes
cursor.execute(
"""
cursor.execute("""
SELECT
indexname,
array_agg(attname ORDER BY arridx),
indisunique,
indisprimary,
array_agg(ordering ORDER BY arridx),
amname,
exprdef,
s2.attoptions
indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary,
array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions
FROM (
SELECT
c2.relname as indexname, idx.*, attr.attname, am.amname,
@@ -245,40 +201,23 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
c2.reloptions as attoptions
FROM (
SELECT *
FROM
pg_index i,
unnest(i.indkey, i.indoption)
WITH ORDINALITY koi(key, option, arridx)
FROM pg_index i, unnest(i.indkey, i.indoption) WITH ORDINALITY koi(key, option, arridx)
) idx
LEFT JOIN pg_class c ON idx.indrelid = c.oid
LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
LEFT JOIN pg_am am ON c2.relam = am.oid
LEFT JOIN
pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
LEFT JOIN pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
) s2
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
""",
[self.index_default_access_method, table_name],
)
for (
index,
columns,
unique,
primary,
orders,
type_,
definition,
options,
) in cursor.fetchall():
""", [self.index_default_access_method, table_name])
for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall():
if index not in constraints:
basic_index = (
type_ == self.index_default_access_method
and
type_ == self.index_default_access_method and
# '_btree' references
# django.contrib.postgres.indexes.BTreeIndex.suffix.
not index.endswith("_btree")
and options is None
not index.endswith('_btree') and options is None
)
constraints[index] = {
"columns": columns if columns != [None] else [],
@@ -2,38 +2,20 @@ from psycopg2.extras import Inet
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "varchar"
explain_prefix = "EXPLAIN"
explain_options = frozenset(
[
"ANALYZE",
"BUFFERS",
"COSTS",
"SETTINGS",
"SUMMARY",
"TIMING",
"VERBOSE",
"WAL",
]
)
cast_char_field_without_max_length = 'varchar'
explain_prefix = 'EXPLAIN'
cast_data_types = {
"AutoField": "integer",
"BigAutoField": "bigint",
"SmallAutoField": "smallint",
'AutoField': 'integer',
'BigAutoField': 'bigint',
'SmallAutoField': 'smallint',
}
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
if internal_type in (
"GenericIPAddressField",
"IPAddressField",
"TimeField",
"UUIDField",
):
if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# https://www.postgresql.org/docs/current/typeconv-union-case.html
@@ -41,19 +23,17 @@ class DatabaseOperations(BaseDatabaseOperations):
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
return (
"CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
)
return "%s"
return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
return '%s'
def date_extract_sql(self, lookup_type, field_name):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
if lookup_type == "week_day":
if lookup_type == 'week_day':
# For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name
elif lookup_type == "iso_week_day":
elif lookup_type == 'iso_week_day':
return "EXTRACT('isodow' FROM %s)" % field_name
elif lookup_type == "iso_year":
elif lookup_type == 'iso_year':
return "EXTRACT('isoyear' FROM %s)" % field_name
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
@@ -64,27 +44,24 @@ class DatabaseOperations(BaseDatabaseOperations):
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
sign = "-" if sign == "+" else "+"
return f"{tzname}{sign}{offset}"
if '+' in tzname:
return tzname.replace('+', '-')
elif '-' in tzname:
return tzname.replace('-', '+')
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % (
field_name,
self._prepare_tzname_delta(tzname),
)
field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "(%s)::date" % field_name
return '(%s)::date' % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "(%s)::time" % field_name
return '(%s)::time' % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
@@ -110,30 +87,21 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
lookup = '%s'
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in (
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"regex",
"iregex",
):
if internal_type in ("IPAddressField", "GenericIPAddressField"):
if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
if internal_type in ('IPAddressField', 'GenericIPAddressField'):
lookup = "HOST(%s)"
elif internal_type in ("CICharField", "CIEmailField", "CITextField"):
lookup = "%s::citext"
elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
lookup = '%s::citext'
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
lookup = "UPPER(%s)" % lookup
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
lookup = 'UPPER(%s)' % lookup
return lookup
@@ -158,32 +126,29 @@ class DatabaseOperations(BaseDatabaseOperations):
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
# to truncate tables referenced by a foreign key in any other table.
sql_parts = [
style.SQL_KEYWORD("TRUNCATE"),
", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
style.SQL_KEYWORD('TRUNCATE'),
', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
]
if reset_sequences:
sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY"))
sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY'))
if allow_cascade:
sql_parts.append(style.SQL_KEYWORD("CASCADE"))
return ["%s;" % " ".join(sql_parts)]
sql_parts.append(style.SQL_KEYWORD('CASCADE'))
return ['%s;' % ' '.join(sql_parts)]
def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
# to reset sequence indices
sql = []
for sequence_info in sequences:
table_name = sequence_info["table"]
table_name = sequence_info['table']
# 'id' will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
column_name = sequence_info["column"] or "id"
sql.append(
"%s setval(pg_get_serial_sequence('%s','%s'), 1, false);"
% (
style.SQL_KEYWORD("SELECT"),
style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name),
)
)
column_name = sequence_info['column'] or 'id'
sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name),
))
return sql
def tablespace_sql(self, tablespace, inline=False):
@@ -194,36 +159,31 @@ class DatabaseOperations(BaseDatabaseOperations):
def sequence_reset_sql(self, style, model_list):
from django.db import models
output = []
qn = self.quote_name
for model in model_list:
# Use `coalesce` to set the sequence for each model to the max pk
# value if there are records, or 1 if there are none. Set the
# `is_called` property (the third argument to `setval`) to true if
# there are records (as the max pk value is already in use),
# otherwise set it to false. Use pg_get_serial_sequence to get the
# underlying sequence name from the table name and column name.
# Use `coalesce` to set the sequence for each model to the max pk value if there are records,
# or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
# if there are records (as the max pk value is already in use), otherwise set it to false.
# Use pg_get_serial_sequence to get the underlying sequence name from the table name
# and column name (available since PostgreSQL 8)
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
"coalesce(max(%s), 1), max(%s) %s null) %s %s;"
% (
style.SQL_KEYWORD("SELECT"),
"coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column),
style.SQL_FIELD(qn(f.column)),
style.SQL_FIELD(qn(f.column)),
style.SQL_KEYWORD("IS NOT"),
style.SQL_KEYWORD("FROM"),
style.SQL_KEYWORD('IS NOT'),
style.SQL_KEYWORD('FROM'),
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
# Only one AutoField is allowed per model, so don't bother
# continuing.
break
break # Only one AutoField is allowed per model, so don't bother continuing.
return output
def prep_for_iexact_query(self, x):
@@ -245,9 +205,9 @@ class DatabaseOperations(BaseDatabaseOperations):
def distinct_sql(self, fields, params):
if fields:
params = [param for param_list in params for param in param_list]
return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
else:
return ["DISTINCT"], []
return ['DISTINCT'], []
def last_executed_query(self, cursor, sql, params):
# https://www.psycopg.org/docs/cursor.html#cursor.query
@@ -258,16 +218,14 @@ class DatabaseOperations(BaseDatabaseOperations):
def return_insert_columns(self, fields):
if not fields:
return "", ()
return '', ()
columns = [
"%s.%s"
% (
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
) for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
return 'RETURNING %s' % ', '.join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
@@ -292,7 +250,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == "DateField":
if internal_type == 'DateField':
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
@@ -300,27 +258,18 @@ class DatabaseOperations(BaseDatabaseOperations):
return super().subtract_temporals(internal_type, lhs, rhs)
def explain_query_prefix(self, format=None, **options):
prefix = super().explain_query_prefix(format)
extra = {}
# Normalize options.
if options:
options = {
name.upper(): "true" if value else "false"
for name, value in options.items()
}
for valid_option in self.explain_options:
value = options.pop(valid_option, None)
if value is not None:
extra[valid_option.upper()] = value
prefix = super().explain_query_prefix(format, **options)
if format:
extra["FORMAT"] = format
extra['FORMAT'] = format
if options:
extra.update({
name.upper(): 'true' if value else 'false'
for name, value in options.items()
})
if extra:
prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items())
prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
return prefix
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return (
"ON CONFLICT DO NOTHING"
if ignore_conflicts
else super().ignore_conflicts_suffix_sql(ignore_conflicts)
)
return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)
@@ -9,18 +9,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_set_sequence_max = (
"SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
)
sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s"
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s'
sql_create_index = (
"CREATE INDEX %(name)s ON %(table)s%(using)s "
"(%(columns)s)%(include)s%(extra)s%(condition)s"
'CREATE INDEX %(name)s ON %(table)s%(using)s '
'(%(columns)s)%(include)s%(extra)s%(condition)s'
)
sql_create_index_concurrently = (
"CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
"(%(columns)s)%(include)s%(extra)s%(condition)s"
'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s '
'(%(columns)s)%(include)s%(extra)s%(condition)s'
)
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
@@ -28,23 +26,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Setting the constraint to IMMEDIATE to allow changing data in the same
# transaction.
sql_create_column_inline_fk = (
"CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
"; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
'; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE'
)
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
# dropping it in the same transaction.
sql_delete_fk = (
"SET CONSTRAINTS %(name)s IMMEDIATE; "
"ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
)
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)'
def quote_value(self, value):
if isinstance(value, str):
value = value.replace("%", "%%")
value = value.replace('%', '%%')
adapted = psycopg2.extensions.adapt(value)
if hasattr(adapted, "encoding"):
adapted.encoding = "utf8"
if hasattr(adapted, 'encoding'):
adapted.encoding = 'utf8'
# getquoted() returns a quoted bytestring of the adapted value.
return adapted.getquoted().decode()
@@ -65,7 +61,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _field_base_data_types(self, field):
# Yield base data types for array fields.
if field.base_field.get_internal_type() == "ArrayField":
if field.base_field.get_internal_type() == 'ArrayField':
yield from self._field_base_data_types(field.base_field)
else:
yield self._field_data_type(field.base_field)
@@ -84,52 +80,45 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
#
# The same doesn't apply to array fields such as varchar[size]
# and text[size], so skip them.
if "[" in db_type:
if '[' in db_type:
return None
if db_type.startswith("varchar"):
if db_type.startswith('varchar'):
return self._create_index_sql(
model,
fields=[field],
suffix="_like",
opclasses=["varchar_pattern_ops"],
suffix='_like',
opclasses=['varchar_pattern_ops'],
)
elif db_type.startswith("text"):
elif db_type.startswith('text'):
return self._create_index_sql(
model,
fields=[field],
suffix="_like",
opclasses=["text_pattern_ops"],
suffix='_like',
opclasses=['text_pattern_ops'],
)
return None
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
# Cast when data type changed.
using_sql = " USING %(column)s::%(type)s"
using_sql = ' USING %(column)s::%(type)s'
new_internal_type = new_field.get_internal_type()
old_internal_type = old_field.get_internal_type()
if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type:
# Compare base data types for array fields.
if list(self._field_base_data_types(old_field)) != list(
self._field_base_data_types(new_field)
):
if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)):
self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql
# Make ALTER TYPE with SERIAL make sense.
table = strip_quotes(model._meta.db_table)
serial_fields_map = {
"bigserial": "bigint",
"serial": "integer",
"smallserial": "smallint",
}
serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'}
if new_type.lower() in serial_fields_map:
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return (
(
self.sql_alter_column_type
% {
self.sql_alter_column_type % {
"column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()],
},
@@ -137,35 +126,29 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
),
[
(
self.sql_delete_sequence
% {
self.sql_delete_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_create_sequence
% {
self.sql_create_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_alter_column
% {
self.sql_alter_column % {
"table": self.quote_name(table),
"changes": self.sql_alter_column_default
% {
"changes": self.sql_alter_column_default % {
"column": self.quote_name(column),
"default": "nextval('%s')"
% self.quote_name(sequence_name),
},
"default": "nextval('%s')" % self.quote_name(sequence_name),
}
},
[],
),
(
self.sql_set_sequence_max
% {
self.sql_set_sequence_max % {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
@@ -173,31 +156,24 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
[],
),
(
self.sql_set_sequence_owner
% {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
self.sql_set_sequence_owner % {
'table': self.quote_name(table),
'column': self.quote_name(column),
'sequence': self.quote_name(sequence_name),
},
[],
),
],
)
elif (
old_field.db_parameters(connection=self.connection)["type"]
in serial_fields_map
):
elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map:
# Drop the sequence if migrating away from AutoField.
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
fragment, _ = super()._alter_column_type_sql(
model, old_field, new_field, new_type
)
sequence_name = '%s_%s_seq' % (table, column)
fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type)
return fragment, [
(
self.sql_delete_sequence
% {
"sequence": self.quote_name(sequence_name),
self.sql_delete_sequence % {
'sequence': self.quote_name(sequence_name),
},
[],
),
@@ -205,114 +181,58 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
else:
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def _alter_field(
self,
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict=False,
):
def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
# Drop indexes on varchar/text/citext columns that are changing to a
# different type.
if (old_field.db_index or old_field.unique) and (
(old_type.startswith("varchar") and not new_type.startswith("varchar"))
or (old_type.startswith("text") and not new_type.startswith("text"))
or (old_type.startswith("citext") and not new_type.startswith("citext"))
(old_type.startswith('varchar') and not new_type.startswith('varchar')) or
(old_type.startswith('text') and not new_type.startswith('text')) or
(old_type.startswith('citext') and not new_type.startswith('citext'))
):
index_name = self._create_index_name(
model._meta.db_table, [old_field.column], suffix="_like"
)
index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_name))
super()._alter_field(
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict,
model, old_field, new_field, old_type, new_type, old_db_params,
new_db_params, strict,
)
# Added an index? Create any PostgreSQL-specific indexes.
if (not (old_field.db_index or old_field.unique) and new_field.db_index) or (
not old_field.unique and new_field.unique
):
if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or
(not old_field.unique and new_field.unique)):
like_index_statement = self._create_like_index_sql(model, new_field)
if like_index_statement is not None:
self.execute(like_index_statement)
# Removed an index? Drop any PostgreSQL-specific indexes.
if old_field.unique and not (new_field.db_index or new_field.unique):
index_to_remove = self._create_index_name(
model._meta.db_table, [old_field.column], suffix="_like"
)
index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
return IndexColumns(
table,
columns,
self.quote_name,
col_suffixes=col_suffixes,
opclasses=opclasses,
)
return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
return super()._index_columns(table, columns, col_suffixes, opclasses)
def add_index(self, model, index, concurrently=False):
self.execute(
index.create_sql(model, self, concurrently=concurrently), params=None
)
self.execute(index.create_sql(model, self, concurrently=concurrently), params=None)
def remove_index(self, model, index, concurrently=False):
self.execute(index.remove_sql(model, self, concurrently=concurrently))
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
sql = (
self.sql_delete_index_concurrently
if concurrently
else self.sql_delete_index
)
sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index
return super()._delete_index_sql(model, name, sql)
def _create_index_sql(
self,
model,
*,
fields=None,
name=None,
suffix="",
using="",
db_tablespace=None,
col_suffixes=(),
sql=None,
opclasses=(),
condition=None,
concurrently=False,
include=None,
expressions=None,
self, model, *, fields=None, name=None, suffix='', using='',
db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
condition=None, concurrently=False, include=None, expressions=None,
):
sql = (
self.sql_create_index
if not concurrently
else self.sql_create_index_concurrently
)
sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently
return super()._create_index_sql(
model,
fields=fields,
name=name,
suffix=suffix,
using=using,
db_tablespace=db_tablespace,
col_suffixes=col_suffixes,
sql=sql,
opclasses=opclasses,
condition=condition,
include=include,
model, fields=fields, name=name, suffix=suffix, using=using,
db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql,
opclasses=opclasses, condition=condition, include=include,
expressions=expressions,
)