测试gitnore

This commit is contained in:
ladeng07
2022-05-06 15:45:57 +08:00
parent 12f390949b
commit 51552904f9
2347 changed files with 120102 additions and 53549 deletions
@@ -6,10 +6,7 @@ import warnings
from collections import deque
from contextlib import contextmanager
try:
import zoneinfo
except ImportError:
from backports import zoneinfo
import pytz
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@@ -23,21 +20,11 @@ from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
NO_DB_ALIAS = "__no_db__"
# RemovedInDjango50Warning
def timezone_constructor(tzname):
if settings.USE_DEPRECATED_PYTZ:
import pytz
return pytz.timezone(tzname)
return zoneinfo.ZoneInfo(tzname)
NO_DB_ALIAS = '__no_db__'
class BaseDatabaseWrapper:
"""Represent a database connection."""
# Mapping of Field objects to their column types.
data_types = {}
# Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
@@ -45,8 +32,8 @@ class BaseDatabaseWrapper:
# Mapping of Field objects to their SQL for CHECK constraints.
data_type_check_constraints = {}
ops = None
vendor = "unknown"
display_name = "unknown"
vendor = 'unknown'
display_name = 'unknown'
SchemaEditorClass = None
# Classes instantiated in __init__().
client_class = None
@@ -145,10 +132,10 @@ class BaseDatabaseWrapper:
"""
if not settings.USE_TZ:
return None
elif self.settings_dict["TIME_ZONE"] is None:
elif self.settings_dict['TIME_ZONE'] is None:
return timezone.utc
else:
return timezone_constructor(self.settings_dict["TIME_ZONE"])
return pytz.timezone(self.settings_dict['TIME_ZONE'])
@cached_property
def timezone_name(self):
@@ -157,10 +144,10 @@ class BaseDatabaseWrapper:
"""
if not settings.USE_TZ:
return settings.TIME_ZONE
elif self.settings_dict["TIME_ZONE"] is None:
return "UTC"
elif self.settings_dict['TIME_ZONE'] is None:
return 'UTC'
else:
return self.settings_dict["TIME_ZONE"]
return self.settings_dict['TIME_ZONE']
@property
def queries_logged(self):
@@ -171,38 +158,26 @@ class BaseDatabaseWrapper:
if len(self.queries_log) == self.queries_log.maxlen:
warnings.warn(
"Limit for query logging exceeded, only the last {} queries "
"will be returned.".format(self.queries_log.maxlen)
)
"will be returned.".format(self.queries_log.maxlen))
return list(self.queries_log)
# ##### Backend-specific methods for creating connections and cursors #####
def get_connection_params(self):
"""Return a dict of parameters suitable for get_new_connection."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a get_connection_params() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_connection_params() method')
def get_new_connection(self, conn_params):
"""Open a connection to the database."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a get_new_connection() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a get_new_connection() method')
def init_connection_state(self):
"""Initialize the database connection settings."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require an init_connection_state() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require an init_connection_state() method')
def create_cursor(self, name=None):
"""Create a cursor. Assume that a connection is established."""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a create_cursor() method"
)
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a create_cursor() method')
# ##### Backend-specific methods for creating connections #####
@@ -216,21 +191,21 @@ class BaseDatabaseWrapper:
self.savepoint_ids = []
self.needs_rollback = False
# Reset parameters defining when to close the connection
max_age = self.settings_dict["CONN_MAX_AGE"]
max_age = self.settings_dict['CONN_MAX_AGE']
self.close_at = None if max_age is None else time.monotonic() + max_age
self.closed_in_transaction = False
self.errors_occurred = False
# Establish the connection
conn_params = self.get_connection_params()
self.connection = self.get_new_connection(conn_params)
self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
self.set_autocommit(self.settings_dict['AUTOCOMMIT'])
self.init_connection_state()
connection_created.send(sender=self.__class__, connection=self)
self.run_on_commit = []
def check_settings(self):
if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
if self.settings_dict['TIME_ZONE'] is not None and not settings.USE_TZ:
raise ImproperlyConfigured(
"Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
% self.alias
@@ -355,7 +330,7 @@ class BaseDatabaseWrapper:
return
thread_ident = _thread.get_ident()
tid = str(thread_ident).replace("-", "")
tid = str(thread_ident).replace('-', '')
self.savepoint_state += 1
sid = "s%s_x%d" % (tid, self.savepoint_state)
@@ -405,9 +380,7 @@ class BaseDatabaseWrapper:
"""
Backend-specific implementation to enable or disable autocommit.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
)
raise NotImplementedError('subclasses of BaseDatabaseWrapper may require a _set_autocommit() method')
# ##### Generic transaction management methods #####
@@ -416,9 +389,7 @@ class BaseDatabaseWrapper:
self.ensure_connection()
return self.autocommit
def set_autocommit(
self, autocommit, force_begin_transaction_with_broken_autocommit=False
):
def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
"""
Enable or disable autocommit.
@@ -434,9 +405,8 @@ class BaseDatabaseWrapper:
self.ensure_connection()
start_transaction_under_autocommit = (
force_begin_transaction_with_broken_autocommit
and not autocommit
and hasattr(self, "_start_transaction_under_autocommit")
force_begin_transaction_with_broken_autocommit and not autocommit and
hasattr(self, '_start_transaction_under_autocommit')
)
if start_transaction_under_autocommit:
@@ -454,8 +424,7 @@ class BaseDatabaseWrapper:
"""Get the "needs rollback" flag -- for *advanced use* only."""
if not self.in_atomic_block:
raise TransactionManagementError(
"The rollback flag doesn't work outside of an 'atomic' block."
)
"The rollback flag doesn't work outside of an 'atomic' block.")
return self.needs_rollback
def set_rollback(self, rollback):
@@ -464,23 +433,20 @@ class BaseDatabaseWrapper:
"""
if not self.in_atomic_block:
raise TransactionManagementError(
"The rollback flag doesn't work outside of an 'atomic' block."
)
"The rollback flag doesn't work outside of an 'atomic' block.")
self.needs_rollback = rollback
def validate_no_atomic_block(self):
"""Raise an error if an atomic block is active."""
if self.in_atomic_block:
raise TransactionManagementError(
"This is forbidden when an 'atomic' block is active."
)
"This is forbidden when an 'atomic' block is active.")
def validate_no_broken_transaction(self):
if self.needs_rollback:
raise TransactionManagementError(
"An error occurred in the current transaction. You can't "
"execute queries until the end of the 'atomic' block."
)
"execute queries until the end of the 'atomic' block.")
# ##### Foreign key constraints checks handling #####
@@ -531,8 +497,7 @@ class BaseDatabaseWrapper:
as that may prevent Django from recycling unusable connections.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseWrapper may require an is_usable() method"
)
"subclasses of BaseDatabaseWrapper may require an is_usable() method")
def close_if_unusable_or_obsolete(self):
"""
@@ -542,7 +507,7 @@ class BaseDatabaseWrapper:
if self.connection is not None:
# If the application didn't restore the original autocommit setting,
# don't take chances, drop the connection.
if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
if self.get_autocommit() != self.settings_dict['AUTOCOMMIT']:
self.close()
return
@@ -573,9 +538,7 @@ class BaseDatabaseWrapper:
def dec_thread_sharing(self):
with self._thread_sharing_lock:
if self._thread_sharing_count <= 0:
raise RuntimeError(
"Cannot decrement the thread sharing count below zero."
)
raise RuntimeError('Cannot decrement the thread sharing count below zero.')
self._thread_sharing_count -= 1
def validate_thread_sharing(self):
@@ -590,7 +553,8 @@ class BaseDatabaseWrapper:
"DatabaseWrapper objects created in a "
"thread can only be used in that same thread. The object "
"with alias '%s' was created in thread id %s and this is "
"thread id %s." % (self.alias, self._thread_ident, _thread.get_ident())
"thread id %s."
% (self.alias, self._thread_ident, _thread.get_ident())
)
# ##### Miscellaneous #####
@@ -651,7 +615,7 @@ class BaseDatabaseWrapper:
being exposed to potential child threads while (or after) the test
database is destroyed. Refs #10868, #17786, #16969.
"""
conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS)
conn = self.__class__({**self.settings_dict, 'NAME': None}, alias=NO_DB_ALIAS)
try:
with conn.cursor() as cursor:
yield cursor
@@ -664,8 +628,7 @@ class BaseDatabaseWrapper:
"""
if self.SchemaEditorClass is None:
raise NotImplementedError(
"The SchemaEditorClass attribute of this database wrapper is still None"
)
'The SchemaEditorClass attribute of this database wrapper is still None')
return self.SchemaEditorClass(self, *args, **kwargs)
def on_commit(self, func):
@@ -675,9 +638,7 @@ class BaseDatabaseWrapper:
# Transaction in progress; save for execution on commit.
self.run_on_commit.append((set(self.savepoint_ids), func))
elif not self.get_autocommit():
raise TransactionManagementError(
"on_commit() cannot be used in manual transaction management"
)
raise TransactionManagementError('on_commit() cannot be used in manual transaction management')
else:
# No transaction in progress and in autocommit mode; execute
# immediately.
@@ -4,7 +4,6 @@ import subprocess
class BaseDatabaseClient:
"""Encapsulate backend-specific methods for opening a client shell."""
# This should be a string representing the name of the executable
# (e.g., "psql"). Subclasses must override this.
executable_name = None
@@ -16,13 +15,11 @@ class BaseDatabaseClient:
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
raise NotImplementedError(
"subclasses of BaseDatabaseClient must provide a "
"settings_to_cmd_args_env() method or override a runshell()."
'subclasses of BaseDatabaseClient must provide a '
'settings_to_cmd_args_env() method or override a runshell().'
)
def runshell(self, parameters):
args, env = self.settings_to_cmd_args_env(
self.connection.settings_dict, parameters
)
args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters)
env = {**os.environ, **env} if env else None
subprocess.run(args, env=env, check=True)
@@ -12,7 +12,7 @@ from django.utils.module_loading import import_string
# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"
TEST_DATABASE_PREFIX = 'test_'
class BaseDatabaseCreation:
@@ -20,7 +20,6 @@ class BaseDatabaseCreation:
Encapsulate backend-specific differences pertaining to creation and
destruction of the test database.
"""
def __init__(self, connection):
self.connection = connection
@@ -30,9 +29,7 @@ class BaseDatabaseCreation:
def log(self, msg):
sys.stderr.write(msg + os.linesep)
def create_test_db(
self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
):
def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False):
"""
Create a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
@@ -43,17 +40,14 @@ class BaseDatabaseCreation:
test_database_name = self._get_test_db_name()
if verbosity >= 1:
action = "Creating"
action = 'Creating'
if keepdb:
action = "Using existing"
self.log(
"%s test database for alias %s..."
% (
action,
self._get_database_display_str(verbosity, test_database_name),
)
)
self.log('%s test database for alias %s...' % (
action,
self._get_database_display_str(verbosity, test_database_name),
))
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. This is to handle the case
@@ -67,24 +61,25 @@ class BaseDatabaseCreation:
self.connection.settings_dict["NAME"] = test_database_name
try:
if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
if self.connection.settings_dict['TEST']['MIGRATE'] is False:
# Disable migrations for all apps.
old_migration_modules = settings.MIGRATION_MODULES
settings.MIGRATION_MODULES = {
app.label: None for app in apps.get_app_configs()
app.label: None
for app in apps.get_app_configs()
}
# We report migrate messages at one level lower than that
# requested. This ensures we don't get flooded with messages during
# testing (unless you really ask to be flooded).
call_command(
"migrate",
'migrate',
verbosity=max(verbosity - 1, 0),
interactive=False,
database=self.connection.alias,
run_syncdb=True,
)
finally:
if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
if self.connection.settings_dict['TEST']['MIGRATE'] is False:
settings.MIGRATION_MODULES = old_migration_modules
# We then serialize the current state of the database into a string
@@ -94,12 +89,12 @@ class BaseDatabaseCreation:
if serialize:
self.connection._test_serialized_contents = self.serialize_db_to_string()
call_command("createcachetable", database=self.connection.alias)
call_command('createcachetable', database=self.connection.alias)
# Ensure a connection for the side effect of initializing the test database.
self.connection.ensure_connection()
if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
if os.environ.get('RUNNING_DJANGOS_TEST_SUITE') == 'true':
self.mark_expected_failures_and_skips()
return test_database_name
@@ -109,7 +104,7 @@ class BaseDatabaseCreation:
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]
self.connection.settings_dict['NAME'] = primary_settings_dict['NAME']
def serialize_db_to_string(self):
"""
@@ -120,23 +115,22 @@ class BaseDatabaseCreation:
# Iteratively return every object for all models to serialize.
def get_objects():
from django.db.migrations.loader import MigrationLoader
loader = MigrationLoader(self.connection)
for app_config in apps.get_app_configs():
if (
app_config.models_module is not None
and app_config.label in loader.migrated_apps
and app_config.name not in settings.TEST_NON_SERIALIZED_APPS
app_config.models_module is not None and
app_config.label in loader.migrated_apps and
app_config.name not in settings.TEST_NON_SERIALIZED_APPS
):
for model in app_config.get_models():
if model._meta.can_migrate(
self.connection
) and router.allow_migrate_model(self.connection.alias, model):
if (
model._meta.can_migrate(self.connection) and
router.allow_migrate_model(self.connection.alias, model)
):
queryset = model._base_manager.using(
self.connection.alias,
).order_by(model._meta.pk.name)
yield from queryset.iterator()
# Serialize to a string
out = StringIO()
serializers.serialize("json", get_objects(), indent=None, stream=out)
@@ -154,9 +148,7 @@ class BaseDatabaseCreation:
# Disable constraint checks, because some databases (MySQL) doesn't
# support deferred checks.
with self.connection.constraint_checks_disabled():
for obj in serializers.deserialize(
"json", data, using=self.connection.alias
):
for obj in serializers.deserialize('json', data, using=self.connection.alias):
obj.save()
table_names.add(obj.object.__class__._meta.db_table)
# Manually check for any invalid keys that might have been added,
@@ -169,7 +161,7 @@ class BaseDatabaseCreation:
"""
return "'%s'%s" % (
self.connection.alias,
(" ('%s')" % database_name) if verbosity >= 2 else "",
(" ('%s')" % database_name) if verbosity >= 2 else '',
)
def _get_test_db_name(self):
@@ -179,12 +171,12 @@ class BaseDatabaseCreation:
_create_test_db() and when no external munging is done with the 'NAME'
settings.
"""
if self.connection.settings_dict["TEST"]["NAME"]:
return self.connection.settings_dict["TEST"]["NAME"]
return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]
if self.connection.settings_dict['TEST']['NAME']:
return self.connection.settings_dict['TEST']['NAME']
return TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME']
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters)
cursor.execute('CREATE DATABASE %(dbname)s %(suffix)s' % parameters)
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
"""
@@ -192,8 +184,8 @@ class BaseDatabaseCreation:
"""
test_database_name = self._get_test_db_name()
test_db_params = {
"dbname": self.connection.ops.quote_name(test_database_name),
"suffix": self.sql_table_creation_suffix(),
'dbname': self.connection.ops.quote_name(test_database_name),
'suffix': self.sql_table_creation_suffix(),
}
# Create the test database and connect to it.
with self._nodb_cursor() as cursor:
@@ -205,30 +197,24 @@ class BaseDatabaseCreation:
if keepdb:
return test_database_name
self.log("Got an error creating the test database: %s" % e)
self.log('Got an error creating the test database: %s' % e)
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name
)
if autoclobber or confirm == "yes":
"database '%s', or 'no' to cancel: " % test_database_name)
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, test_database_name
),
)
)
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, test_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log("Got an error recreating the test database: %s" % e)
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
else:
self.log("Tests cancelled.")
self.log('Tests cancelled.')
sys.exit(1)
return test_database_name
@@ -237,19 +223,16 @@ class BaseDatabaseCreation:
"""
Clone a test database.
"""
source_database_name = self.connection.settings_dict["NAME"]
source_database_name = self.connection.settings_dict['NAME']
if verbosity >= 1:
action = "Cloning test database"
action = 'Cloning test database'
if keepdb:
action = "Using existing clone"
self.log(
"%s for alias %s..."
% (
action,
self._get_database_display_str(verbosity, source_database_name),
)
)
action = 'Using existing clone'
self.log('%s for alias %s...' % (
action,
self._get_database_display_str(verbosity, source_database_name),
))
# We could skip this call if keepdb is True, but we instead
# give it the keepdb param. See create_test_db for details.
@@ -263,10 +246,7 @@ class BaseDatabaseCreation:
# already and its name has been copied to settings_dict['NAME'] so
# we don't need to call _get_test_db_name.
orig_settings_dict = self.connection.settings_dict
return {
**orig_settings_dict,
"NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
}
return {**orig_settings_dict, 'NAME': '{}_{}'.format(orig_settings_dict['NAME'], suffix)}
def _clone_test_db(self, suffix, verbosity, keepdb=False):
"""
@@ -274,33 +254,27 @@ class BaseDatabaseCreation:
"""
raise NotImplementedError(
"The database backend doesn't support cloning databases. "
"Disable the option to run tests in parallel processes."
)
"Disable the option to run tests in parallel processes.")
def destroy_test_db(
self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
):
def destroy_test_db(self, old_database_name=None, verbosity=1, keepdb=False, suffix=None):
"""
Destroy a test database, prompting the user for confirmation if the
database already exists.
"""
self.connection.close()
if suffix is None:
test_database_name = self.connection.settings_dict["NAME"]
test_database_name = self.connection.settings_dict['NAME']
else:
test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
test_database_name = self.get_test_db_clone_settings(suffix)['NAME']
if verbosity >= 1:
action = "Destroying"
action = 'Destroying'
if keepdb:
action = "Preserving"
self.log(
"%s test database for alias %s..."
% (
action,
self._get_database_display_str(verbosity, test_database_name),
)
)
action = 'Preserving'
self.log('%s test database for alias %s...' % (
action,
self._get_database_display_str(verbosity, test_database_name),
))
# if we want to preserve the database
# skip the actual destroying piece.
@@ -321,9 +295,8 @@ class BaseDatabaseCreation:
# to do so, because it's not allowed to delete a database while being
# connected to it.
with self._nodb_cursor() as cursor:
cursor.execute(
"DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
)
cursor.execute("DROP DATABASE %s"
% self.connection.ops.quote_name(test_database_name))
def mark_expected_failures_and_skips(self):
"""
@@ -331,8 +304,8 @@ class BaseDatabaseCreation:
database and test which should be skipped on this database.
"""
for test_name in self.connection.features.django_test_expected_failures:
test_case_name, _, test_method_name = test_name.rpartition(".")
test_app = test_name.split(".")[0]
test_case_name, _, test_method_name = test_name.rpartition('.')
test_app = test_name.split('.')[0]
# Importing a test app that isn't installed raises RuntimeError.
if test_app in settings.INSTALLED_APPS:
test_case = import_string(test_case_name)
@@ -340,8 +313,8 @@ class BaseDatabaseCreation:
setattr(test_case, test_method_name, expectedFailure(test_method))
for reason, tests in self.connection.features.django_test_skips.items():
for test_name in tests:
test_case_name, _, test_method_name = test_name.rpartition(".")
test_app = test_name.split(".")[0]
test_case_name, _, test_method_name = test_name.rpartition('.')
test_app = test_name.split('.')[0]
# Importing a test app that isn't installed raises RuntimeError.
if test_app in settings.INSTALLED_APPS:
test_case = import_string(test_case_name)
@@ -352,7 +325,7 @@ class BaseDatabaseCreation:
"""
SQL to append to the end of the test table creation statements.
"""
return ""
return ''
def test_db_signature(self):
"""
@@ -362,8 +335,8 @@ class BaseDatabaseCreation:
"""
settings_dict = self.connection.settings_dict
return (
settings_dict["HOST"],
settings_dict["PORT"],
settings_dict["ENGINE"],
settings_dict['HOST'],
settings_dict['PORT'],
settings_dict['ENGINE'],
self._get_test_db_name(),
)
@@ -64,9 +64,6 @@ class BaseDatabaseFeatures:
has_real_datatype = False
supports_subqueries_in_group_by = True
# Does the backend ignore unnecessary ORDER BY clauses in subqueries?
ignores_unnecessary_order_by_in_subqueries = True
# Is there a true datatype for uuid?
has_native_uuid_field = False
@@ -133,21 +130,21 @@ class BaseDatabaseFeatures:
# Map fields which some backends may not be able to differentiate to the
# field it's introspected as.
introspected_field_types = {
"AutoField": "AutoField",
"BigAutoField": "BigAutoField",
"BigIntegerField": "BigIntegerField",
"BinaryField": "BinaryField",
"BooleanField": "BooleanField",
"CharField": "CharField",
"DurationField": "DurationField",
"GenericIPAddressField": "GenericIPAddressField",
"IntegerField": "IntegerField",
"PositiveBigIntegerField": "PositiveBigIntegerField",
"PositiveIntegerField": "PositiveIntegerField",
"PositiveSmallIntegerField": "PositiveSmallIntegerField",
"SmallAutoField": "SmallAutoField",
"SmallIntegerField": "SmallIntegerField",
"TimeField": "TimeField",
'AutoField': 'AutoField',
'BigAutoField': 'BigAutoField',
'BigIntegerField': 'BigIntegerField',
'BinaryField': 'BinaryField',
'BooleanField': 'BooleanField',
'CharField': 'CharField',
'DurationField': 'DurationField',
'GenericIPAddressField': 'GenericIPAddressField',
'IntegerField': 'IntegerField',
'PositiveBigIntegerField': 'PositiveBigIntegerField',
'PositiveIntegerField': 'PositiveIntegerField',
'PositiveSmallIntegerField': 'PositiveSmallIntegerField',
'SmallAutoField': 'SmallAutoField',
'SmallIntegerField': 'SmallIntegerField',
'TimeField': 'TimeField',
}
# Can the backend introspect the column order (ASC/DESC) for indexes?
@@ -204,7 +201,7 @@ class BaseDatabaseFeatures:
has_case_insensitive_like = True
# Suffix for backends that don't support "SELECT xxx;" queries.
bare_select_suffix = ""
bare_select_suffix = ''
# If NULL is implied on columns without needing to be explicitly specified
implied_column_null = False
@@ -324,13 +321,11 @@ class BaseDatabaseFeatures:
# Collation names for use by the Django test suite.
test_collations = {
"ci": None, # Case-insensitive.
"cs": None, # Case-sensitive.
"non_default": None, # Non-default.
"swedish_ci": None, # Swedish case-insensitive.
'ci': None, # Case-insensitive.
'cs': None, # Case-sensitive.
'non_default': None, # Non-default.
'swedish_ci': None # Swedish case-insensitive.
}
# SQL template override for tests.aggregation.tests.NowUTC
test_now_utc_template = None
# A set of dotted paths to tests in Django's test suite that are expected
# to fail on this database.
@@ -351,14 +346,14 @@ class BaseDatabaseFeatures:
def supports_transactions(self):
"""Confirm support for transactions."""
with self.connection.cursor() as cursor:
cursor.execute("CREATE TABLE ROLLBACK_TEST (X INT)")
cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)')
self.connection.set_autocommit(False)
cursor.execute("INSERT INTO ROLLBACK_TEST (X) VALUES (8)")
cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)')
self.connection.rollback()
self.connection.set_autocommit(True)
cursor.execute("SELECT COUNT(X) FROM ROLLBACK_TEST")
(count,) = cursor.fetchone()
cursor.execute("DROP TABLE ROLLBACK_TEST")
cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST')
count, = cursor.fetchone()
cursor.execute('DROP TABLE ROLLBACK_TEST')
return count == 0
def allows_group_by_selected_pks_on_model(self, model):
@@ -1,19 +1,18 @@
from collections import namedtuple
# Structure returned by DatabaseIntrospection.get_table_list()
TableInfo = namedtuple("TableInfo", ["name", "type"])
TableInfo = namedtuple('TableInfo', ['name', 'type'])
# Structure returned by the DB-API cursor.description interface (PEP 249)
FieldInfo = namedtuple(
"FieldInfo",
"name type_code display_size internal_size precision scale null_ok "
"default collation",
'FieldInfo',
'name type_code display_size internal_size precision scale null_ok '
'default collation'
)
class BaseDatabaseIntrospection:
"""Encapsulate backend-specific introspection utilities."""
data_types_reverse = {}
def __init__(self, connection):
@@ -44,14 +43,9 @@ class BaseDatabaseIntrospection:
the database's ORDER BY here to avoid subtle differences in sorting
order between databases.
"""
def get_names(cursor):
return sorted(
ti.name
for ti in self.get_table_list(cursor)
if include_views or ti.type == "t"
)
return sorted(ti.name for ti in self.get_table_list(cursor)
if include_views or ti.type == 't')
if cursor is None:
with self.connection.cursor() as cursor:
return get_names(cursor)
@@ -62,10 +56,7 @@ class BaseDatabaseIntrospection:
Return an unsorted list of TableInfo named tuples of all tables and
views that exist in the database.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseIntrospection may require a get_table_list() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_table_list() method')
def get_table_description(self, cursor, table_name):
"""
@@ -73,14 +64,13 @@ class BaseDatabaseIntrospection:
interface.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseIntrospection may require a "
"get_table_description() method."
'subclasses of BaseDatabaseIntrospection may require a '
'get_table_description() method.'
)
def get_migratable_models(self):
from django.apps import apps
from django.db import router
return (
model
for app_config in apps.get_app_configs()
@@ -101,15 +91,16 @@ class BaseDatabaseIntrospection:
continue
tables.add(model._meta.db_table)
tables.update(
f.m2m_db_table()
for f in model._meta.local_many_to_many
f.m2m_db_table() for f in model._meta.local_many_to_many
if f.remote_field.through._meta.managed
)
tables = list(tables)
if only_existing:
existing_tables = set(self.table_names(include_views=include_views))
tables = [
t for t in tables if self.identifier_converter(t) in existing_tables
t
for t in tables
if self.identifier_converter(t) in existing_tables
]
return tables
@@ -120,8 +111,7 @@ class BaseDatabaseIntrospection:
"""
tables = set(map(self.identifier_converter, tables))
return {
m
for m in self.get_migratable_models()
m for m in self.get_migratable_models()
if self.identifier_converter(m._meta.db_table) in tables
}
@@ -137,19 +127,13 @@ class BaseDatabaseIntrospection:
continue
if model._meta.swapped:
continue
sequence_list.extend(
self.get_sequences(
cursor, model._meta.db_table, model._meta.local_fields
)
)
sequence_list.extend(self.get_sequences(cursor, model._meta.db_table, model._meta.local_fields))
for f in model._meta.local_many_to_many:
# If this is an m2m using an intermediate table,
# we don't need to reset the sequence.
if f.remote_field.through._meta.auto_created:
sequence = self.get_sequences(cursor, f.m2m_db_table())
sequence_list.extend(
sequence or [{"table": f.m2m_db_table(), "column": None}]
)
sequence_list.extend(sequence or [{'table': f.m2m_db_table(), 'column': None}])
return sequence_list
def get_sequences(self, cursor, table_name, table_fields=()):
@@ -158,10 +142,7 @@ class BaseDatabaseIntrospection:
is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
'name' key can be added if the backend supports named sequences.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseIntrospection may require a get_sequences() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_sequences() method')
def get_relations(self, cursor, table_name):
"""
@@ -170,8 +151,8 @@ class BaseDatabaseIntrospection:
relationships to the given table.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseIntrospection may require a "
"get_relations() method."
'subclasses of BaseDatabaseIntrospection may require a '
'get_relations() method.'
)
def get_key_columns(self, cursor, table_name):
@@ -180,18 +161,15 @@ class BaseDatabaseIntrospection:
(column_name, referenced_table_name, referenced_column_name)
for all key columns in given table.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseIntrospection may require a get_key_columns() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_key_columns() method')
def get_primary_key_column(self, cursor, table_name):
"""
Return the name of the primary key column for the given table.
"""
for constraint in self.get_constraints(cursor, table_name).values():
if constraint["primary_key"]:
return constraint["columns"][0]
if constraint['primary_key']:
return constraint['columns'][0]
return None
def get_constraints(self, cursor, table_name):
@@ -213,7 +191,4 @@ class BaseDatabaseIntrospection:
Some backends may return special constraint names that don't exist
if they don't name constraints of a certain type (e.g. SQLite)
"""
raise NotImplementedError(
"subclasses of BaseDatabaseIntrospection may require a get_constraints() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseIntrospection may require a get_constraints() method')
@@ -16,26 +16,25 @@ class BaseDatabaseOperations:
Encapsulate backend-specific differences, such as the way a backend
performs ordering or calculates the ID of a recently-inserted row.
"""
compiler_module = "django.db.models.sql.compiler"
# Integer field safe ranges by `internal_type` as documented
# in docs/ref/models/fields.txt.
integer_field_ranges = {
"SmallIntegerField": (-32768, 32767),
"IntegerField": (-2147483648, 2147483647),
"BigIntegerField": (-9223372036854775808, 9223372036854775807),
"PositiveBigIntegerField": (0, 9223372036854775807),
"PositiveSmallIntegerField": (0, 32767),
"PositiveIntegerField": (0, 2147483647),
"SmallAutoField": (-32768, 32767),
"AutoField": (-2147483648, 2147483647),
"BigAutoField": (-9223372036854775808, 9223372036854775807),
'SmallIntegerField': (-32768, 32767),
'IntegerField': (-2147483648, 2147483647),
'BigIntegerField': (-9223372036854775808, 9223372036854775807),
'PositiveBigIntegerField': (0, 9223372036854775807),
'PositiveSmallIntegerField': (0, 32767),
'PositiveIntegerField': (0, 2147483647),
'SmallAutoField': (-32768, 32767),
'AutoField': (-2147483648, 2147483647),
'BigAutoField': (-9223372036854775808, 9223372036854775807),
}
set_operators = {
"union": "UNION",
"intersection": "INTERSECT",
"difference": "EXCEPT",
'union': 'UNION',
'intersection': 'INTERSECT',
'difference': 'EXCEPT',
}
# Mapping of Field.get_internal_type() (typically the model field's class
# name) to the data type to use for the Cast() function, if different from
@@ -45,11 +44,11 @@ class BaseDatabaseOperations:
cast_char_field_without_max_length = None
# Start and end points for window expressions.
PRECEDING = "PRECEDING"
FOLLOWING = "FOLLOWING"
UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
CURRENT_ROW = "CURRENT ROW"
PRECEDING = 'PRECEDING'
FOLLOWING = 'FOLLOWING'
UNBOUNDED_PRECEDING = 'UNBOUNDED ' + PRECEDING
UNBOUNDED_FOLLOWING = 'UNBOUNDED ' + FOLLOWING
CURRENT_ROW = 'CURRENT ROW'
# Prefix for EXPLAIN queries, or None EXPLAIN isn't supported.
explain_prefix = None
@@ -91,17 +90,14 @@ class BaseDatabaseOperations:
to that type. The resulting string should contain a '%s' placeholder
for the expression being cast.
"""
return "%s"
return '%s'
def date_extract_sql(self, lookup_type, field_name):
"""
Given a lookup_type of 'year', 'month', or 'day', return the SQL that
extracts a value from the given date field field_name.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a date_extract_sql() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_extract_sql() method')
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
@@ -112,28 +108,22 @@ class BaseDatabaseOperations:
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
"method."
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a date_trunc_sql() method.')
def datetime_cast_date_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to date value.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a "
"datetime_cast_date_sql() method."
'subclasses of BaseDatabaseOperations may require a '
'datetime_cast_date_sql() method.'
)
def datetime_cast_time_sql(self, field_name, tzname):
"""
Return the SQL to cast a datetime value to time value.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a "
"datetime_cast_time_sql() method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_cast_time_sql() method')
def datetime_extract_sql(self, lookup_type, field_name, tzname):
"""
@@ -141,10 +131,7 @@ class BaseDatabaseOperations:
'second', return the SQL that extracts a value from the given
datetime field field_name.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_extract_sql() method')
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
"""
@@ -152,10 +139,7 @@ class BaseDatabaseOperations:
'second', return the SQL that truncates the given datetime field
field_name to a datetime object with only the given specificity.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
"method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() method')
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
"""
@@ -166,9 +150,7 @@ class BaseDatabaseOperations:
If `tzname` is provided, the given value is truncated in a specific
timezone.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a time_trunc_sql() method')
def time_extract_sql(self, lookup_type, field_name):
"""
@@ -182,7 +164,7 @@ class BaseDatabaseOperations:
Return the SQL to make a constraint "initially deferred" during a
CREATE TABLE statement.
"""
return ""
return ''
def distinct_sql(self, fields, params):
"""
@@ -191,11 +173,9 @@ class BaseDatabaseOperations:
duplicates.
"""
if fields:
raise NotSupportedError(
"DISTINCT ON fields is not supported by this database backend"
)
raise NotSupportedError('DISTINCT ON fields is not supported by this database backend')
else:
return ["DISTINCT"], []
return ['DISTINCT'], []
def fetch_returned_insert_columns(self, cursor, returning_params):
"""
@@ -211,7 +191,7 @@ class BaseDatabaseOperations:
it in a WHERE statement. The resulting string should contain a '%s'
placeholder for the column being searched against.
"""
return "%s"
return '%s'
def force_no_ordering(self):
"""
@@ -224,11 +204,11 @@ class BaseDatabaseOperations:
"""
Return the FOR UPDATE SQL clause to lock rows for an update operation.
"""
return "FOR%s UPDATE%s%s%s" % (
" NO KEY" if no_key else "",
" OF %s" % ", ".join(of) if of else "",
" NOWAIT" if nowait else "",
" SKIP LOCKED" if skip_locked else "",
return 'FOR%s UPDATE%s%s%s' % (
' NO KEY' if no_key else '',
' OF %s' % ', '.join(of) if of else '',
' NOWAIT' if nowait else '',
' SKIP LOCKED' if skip_locked else '',
)
def _get_limit_offset_params(self, low_mark, high_mark):
@@ -242,14 +222,10 @@ class BaseDatabaseOperations:
def limit_offset_sql(self, low_mark, high_mark):
"""Return LIMIT/OFFSET SQL clause."""
limit, offset = self._get_limit_offset_params(low_mark, high_mark)
return " ".join(
sql
for sql in (
("LIMIT %d" % limit) if limit else None,
("OFFSET %d" % offset) if offset else None,
)
if sql
)
return ' '.join(sql for sql in (
('LIMIT %d' % limit) if limit else None,
('OFFSET %d' % offset) if offset else None,
) if sql)
def last_executed_query(self, cursor, sql, params):
"""
@@ -263,8 +239,7 @@ class BaseDatabaseOperations:
"""
# Convert params to contain string values.
def to_string(s):
return force_str(s, strings_only=True, errors="replace")
return force_str(s, strings_only=True, errors='replace')
if isinstance(params, (list, tuple)):
u_params = tuple(to_string(val) for val in params)
elif params is None:
@@ -310,16 +285,14 @@ class BaseDatabaseOperations:
Return the value to use for the LIMIT when we are wanting "LIMIT
infinity". Return None if the limit clause can be omitted in this case.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a no_limit_value() method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a no_limit_value() method')
def pk_default_value(self):
"""
Return the value to use during an INSERT statement to specify that
the field should use its default value.
"""
return "DEFAULT"
return 'DEFAULT'
def prepare_sql_script(self, sql):
"""
@@ -332,8 +305,7 @@ class BaseDatabaseOperations:
"""
return [
sqlparse.format(statement, strip_comments=True)
for statement in sqlparse.split(sql)
if statement
for statement in sqlparse.split(sql) if statement
]
def process_clob(self, value):
@@ -366,9 +338,7 @@ class BaseDatabaseOperations:
Return a quoted version of the given table, index, or column name. Do
not quote the given name if it's already been quoted.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a quote_name() method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a quote_name() method')
def regex_lookup(self, lookup_type):
"""
@@ -379,9 +349,7 @@ class BaseDatabaseOperations:
If the feature is not supported (or part of it is not supported), raise
NotImplementedError.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a regex_lookup() method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations may require a regex_lookup() method')
def savepoint_create_sql(self, sid):
"""
@@ -409,7 +377,7 @@ class BaseDatabaseOperations:
Return '' if the backend doesn't support time zones.
"""
return ""
return ''
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
"""
@@ -427,9 +395,7 @@ class BaseDatabaseOperations:
to tables with foreign keys pointing the tables being truncated.
PostgreSQL requires a cascade even if these tables are empty.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations must provide an sql_flush() method"
)
raise NotImplementedError('subclasses of BaseDatabaseOperations must provide an sql_flush() method')
def execute_sql_flush(self, sql_list):
"""Execute a list of SQL statements to flush the database."""
@@ -480,7 +446,7 @@ class BaseDatabaseOperations:
If `inline` is True, append the SQL to a row; otherwise append it to
the entire CREATE TABLE or CREATE INDEX statement.
"""
return ""
return ''
def prep_for_like_query(self, x):
"""Prepare a value for use in a LIKE query."""
@@ -506,7 +472,7 @@ class BaseDatabaseOperations:
cases where the target type isn't known, such as .raw() SQL queries.
As a consequence it may not work perfectly in all circumstances.
"""
if isinstance(value, datetime.datetime): # must be before date
if isinstance(value, datetime.datetime): # must be before date
return self.adapt_datetimefield_value(value)
elif isinstance(value, datetime.date):
return self.adapt_datefield_value(value)
@@ -560,44 +526,30 @@ class BaseDatabaseOperations:
"""
return value or None
def year_lookup_bounds_for_date_field(self, value, iso_year=False):
def year_lookup_bounds_for_date_field(self, value):
"""
Return a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateField value using a year
lookup.
`value` is an int, containing the looked-up year.
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
"""
if iso_year:
first = datetime.date.fromisocalendar(value, 1, 1)
second = datetime.date.fromisocalendar(
value + 1, 1, 1
) - datetime.timedelta(days=1)
else:
first = datetime.date(value, 1, 1)
second = datetime.date(value, 12, 31)
first = datetime.date(value, 1, 1)
second = datetime.date(value, 12, 31)
first = self.adapt_datefield_value(first)
second = self.adapt_datefield_value(second)
return [first, second]
def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
def year_lookup_bounds_for_datetime_field(self, value):
"""
Return a two-elements list with the lower and upper bound to be used
with a BETWEEN operator to query a DateTimeField value using a year
lookup.
`value` is an int, containing the looked-up year.
If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
"""
if iso_year:
first = datetime.datetime.fromisocalendar(value, 1, 1)
second = datetime.datetime.fromisocalendar(
value + 1, 1, 1
) - datetime.timedelta(microseconds=1)
else:
first = datetime.datetime(value, 1, 1)
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
first = datetime.datetime(value, 1, 1)
second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
if settings.USE_TZ:
tz = timezone.get_current_timezone()
first = timezone.make_aware(first, tz)
@@ -644,7 +596,7 @@ class BaseDatabaseOperations:
can vary between backends (e.g., Oracle with %% and &) and between
subexpression types (e.g., date expressions).
"""
conn = " %s " % connector
conn = ' %s ' % connector
return conn.join(sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
@@ -655,7 +607,7 @@ class BaseDatabaseOperations:
Some backends require special syntax to insert binary content (MySQL
for example uses '_binary %s').
"""
return "%s"
return '%s'
def modify_insert_params(self, placeholder, params):
"""
@@ -676,76 +628,66 @@ class BaseDatabaseOperations:
if self.connection.features.supports_temporal_subtraction:
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
return "(%s - %s)" % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
raise NotSupportedError(
"This backend does not support %s subtraction." % internal_type
)
return '(%s - %s)' % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
raise NotSupportedError("This backend does not support %s subtraction." % internal_type)
def window_frame_start(self, start):
if isinstance(start, int):
if start < 0:
return "%d %s" % (abs(start), self.PRECEDING)
return '%d %s' % (abs(start), self.PRECEDING)
elif start == 0:
return self.CURRENT_ROW
elif start is None:
return self.UNBOUNDED_PRECEDING
raise ValueError(
"start argument must be a negative integer, zero, or None, but got '%s'."
% start
)
raise ValueError("start argument must be a negative integer, zero, or None, but got '%s'." % start)
def window_frame_end(self, end):
if isinstance(end, int):
if end == 0:
return self.CURRENT_ROW
elif end > 0:
return "%d %s" % (end, self.FOLLOWING)
return '%d %s' % (end, self.FOLLOWING)
elif end is None:
return self.UNBOUNDED_FOLLOWING
raise ValueError(
"end argument must be a positive integer, zero, or None, but got '%s'."
% end
)
raise ValueError("end argument must be a positive integer, zero, or None, but got '%s'." % end)
def window_frame_rows_start_end(self, start=None, end=None):
"""
Return SQL for start and end points in an OVER clause window frame.
"""
if not self.connection.features.supports_over_clause:
raise NotSupportedError("This backend does not support window expressions.")
raise NotSupportedError('This backend does not support window expressions.')
return self.window_frame_start(start), self.window_frame_end(end)
def window_frame_range_start_end(self, start=None, end=None):
start_, end_ = self.window_frame_rows_start_end(start, end)
features = self.connection.features
if features.only_supports_unbounded_with_preceding_and_following and (
(start and start < 0) or (end and end > 0)
if (
self.connection.features.only_supports_unbounded_with_preceding_and_following and
((start and start < 0) or (end and end > 0))
):
raise NotSupportedError(
"%s only supports UNBOUNDED together with PRECEDING and "
"FOLLOWING." % self.connection.display_name
'%s only supports UNBOUNDED together with PRECEDING and '
'FOLLOWING.' % self.connection.display_name
)
return start_, end_
def explain_query_prefix(self, format=None, **options):
if not self.connection.features.supports_explaining_query_execution:
raise NotSupportedError(
"This backend does not support explaining query execution."
)
raise NotSupportedError('This backend does not support explaining query execution.')
if format:
supported_formats = self.connection.features.supported_explain_formats
normalized_format = format.upper()
if normalized_format not in supported_formats:
msg = "%s is not a recognized format." % normalized_format
msg = '%s is not a recognized format.' % normalized_format
if supported_formats:
msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats))
msg += ' Allowed formats: %s' % ', '.join(sorted(supported_formats))
raise ValueError(msg)
if options:
raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys())))
raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
return self.explain_prefix
def insert_statement(self, ignore_conflicts=False):
return "INSERT INTO"
return 'INSERT INTO'
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return ""
return ''
File diff suppressed because it is too large Load Diff
@@ -1,6 +1,5 @@
class BaseDatabaseValidation:
"""Encapsulate backend-specific validation."""
def __init__(self, connection):
self.connection = connection
@@ -10,12 +9,9 @@ class BaseDatabaseValidation:
def check_field(self, field, **kwargs):
errors = []
# Backends may implement a check_field_type() method.
if (
hasattr(self, "check_field_type")
and
# Ignore any related fields.
not getattr(field, "remote_field", None)
):
if (hasattr(self, 'check_field_type') and
# Ignore any related fields.
not getattr(field, 'remote_field', None)):
# Ignore fields with unsupported features.
db_supports_all_required_features = all(
getattr(self.connection.features, feature, False)
@@ -33,12 +33,10 @@ class Reference:
pass
def __repr__(self):
return "<%s %r>" % (self.__class__.__name__, str(self))
return '<%s %r>' % (self.__class__.__name__, str(self))
def __str__(self):
raise NotImplementedError(
"Subclasses must define how they should be converted to string."
)
raise NotImplementedError('Subclasses must define how they should be converted to string.')
class Table(Reference):
@@ -90,14 +88,12 @@ class Columns(TableColumns):
try:
suffix = self.col_suffixes[idx]
if suffix:
col = "{} {}".format(col, suffix)
col = '{} {}'.format(col, suffix)
except IndexError:
pass
return col
return ", ".join(
col_str(column, idx) for idx, column in enumerate(self.columns)
)
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
class IndexName(TableColumns):
@@ -121,49 +117,35 @@ class IndexColumns(Columns):
def col_str(column, idx):
# Index.__init__() guarantees that self.opclasses is the same
# length as self.columns.
col = "{} {}".format(self.quote_name(column), self.opclasses[idx])
col = '{} {}'.format(self.quote_name(column), self.opclasses[idx])
try:
suffix = self.col_suffixes[idx]
if suffix:
col = "{} {}".format(col, suffix)
col = '{} {}'.format(col, suffix)
except IndexError:
pass
return col
return ", ".join(
col_str(column, idx) for idx, column in enumerate(self.columns)
)
return ', '.join(col_str(column, idx) for idx, column in enumerate(self.columns))
class ForeignKeyName(TableColumns):
"""Hold a reference to a foreign key name."""
def __init__(
self,
from_table,
from_columns,
to_table,
to_columns,
suffix_template,
create_fk_name,
):
def __init__(self, from_table, from_columns, to_table, to_columns, suffix_template, create_fk_name):
self.to_reference = TableColumns(to_table, to_columns)
self.suffix_template = suffix_template
self.create_fk_name = create_fk_name
super().__init__(
from_table,
from_columns,
)
super().__init__(from_table, from_columns,)
def references_table(self, table):
return super().references_table(table) or self.to_reference.references_table(
table
)
return super().references_table(table) or self.to_reference.references_table(table)
def references_column(self, table, column):
return super().references_column(
table, column
) or self.to_reference.references_column(table, column)
return (
super().references_column(table, column) or
self.to_reference.references_column(table, column)
)
def rename_table_references(self, old_table, new_table):
super().rename_table_references(old_table, new_table)
@@ -175,8 +157,8 @@ class ForeignKeyName(TableColumns):
def __str__(self):
suffix = self.suffix_template % {
"to_table": self.to_reference.table,
"to_column": self.to_reference.columns[0],
'to_table': self.to_reference.table,
'to_column': self.to_reference.columns[0],
}
return self.create_fk_name(self.table, self.columns, suffix)
@@ -189,31 +171,30 @@ class Statement(Reference):
that might have to be adjusted if they're referencing a table or column
that is removed
"""
def __init__(self, template, **parts):
self.template = template
self.parts = parts
def references_table(self, table):
return any(
hasattr(part, "references_table") and part.references_table(table)
hasattr(part, 'references_table') and part.references_table(table)
for part in self.parts.values()
)
def references_column(self, table, column):
return any(
hasattr(part, "references_column") and part.references_column(table, column)
hasattr(part, 'references_column') and part.references_column(table, column)
for part in self.parts.values()
)
def rename_table_references(self, old_table, new_table):
for part in self.parts.values():
if hasattr(part, "rename_table_references"):
if hasattr(part, 'rename_table_references'):
part.rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
for part in self.parts.values():
if hasattr(part, "rename_column_references"):
if hasattr(part, 'rename_column_references'):
part.rename_column_references(table, old_column, new_column)
def __str__(self):
@@ -225,16 +206,17 @@ class Expressions(TableColumns):
self.compiler = compiler
self.expressions = expressions
self.quote_value = quote_value
columns = [
col.target.column
for col in self.compiler.query._gen_cols([self.expressions])
]
columns = [col.target.column for col in self.compiler.query._gen_cols([self.expressions])]
super().__init__(table, columns)
def rename_table_references(self, old_table, new_table):
if self.table != old_table:
return
self.expressions = self.expressions.relabeled_clone({old_table: new_table})
expressions = deepcopy(self.expressions)
self.columns = []
for col in self.compiler.query._gen_cols([expressions]):
col.alias = new_table
self.expressions = expressions
super().rename_table_references(old_table, new_table)
def rename_column_references(self, table, old_column, new_column):
@@ -17,11 +17,9 @@ from django.db.backends.dummy.features import DummyDatabaseFeatures
def complain(*args, **kwargs):
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the ENGINE value. Check "
"settings documentation for more details."
)
raise ImproperlyConfigured("settings.DATABASES is improperly configured. "
"Please supply the ENGINE value. Check "
"settings documentation for more details.")
def ignore(*args, **kwargs):
@@ -15,7 +15,8 @@ try:
import MySQLdb as Database
except ImportError as err:
raise ImproperlyConfigured(
"Error loading MySQLdb module.\nDid you install mysqlclient?"
'Error loading MySQLdb module.\n'
'Did you install mysqlclient?'
) from err
from MySQLdb.constants import CLIENT, FIELD_TYPE
@@ -32,9 +33,7 @@ from .validation import DatabaseValidation
version = Database.version_info
if version < (1, 4, 0):
raise ImproperlyConfigured(
"mysqlclient 1.4.0 or newer is required; you have %s." % Database.__version__
)
raise ImproperlyConfigured('mysqlclient 1.4.0 or newer is required; you have %s.' % Database.__version__)
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
@@ -47,7 +46,7 @@ django_conversions = {
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same).
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
server_version_re = _lazy_re_compile(r'(\d{1,2})\.(\d{1,2})\.(\d{1,2})')
class CursorWrapper:
@@ -58,7 +57,6 @@ class CursorWrapper:
Implemented as a wrapper, rather than a subclass, so that it isn't stuck
to the particular underlying representation returned by Connection.cursor().
"""
codes_for_integrityerror = (
1048, # Column cannot be null
1690, # BIGINT UNSIGNED value is out of range
@@ -98,39 +96,40 @@ class CursorWrapper:
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "mysql"
vendor = 'mysql'
# This dictionary maps Field objects to their associated MySQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
"AutoField": "integer AUTO_INCREMENT",
"BigAutoField": "bigint AUTO_INCREMENT",
"BinaryField": "longblob",
"BooleanField": "bool",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "datetime(6)",
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
"DurationField": "bigint",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "double precision",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "char(15)",
"GenericIPAddressField": "char(39)",
"JSONField": "json",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint UNSIGNED",
"PositiveIntegerField": "integer UNSIGNED",
"PositiveSmallIntegerField": "smallint UNSIGNED",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallint AUTO_INCREMENT",
"SmallIntegerField": "smallint",
"TextField": "longtext",
"TimeField": "time(6)",
"UUIDField": "char(32)",
'AutoField': 'integer AUTO_INCREMENT',
'BigAutoField': 'bigint AUTO_INCREMENT',
'BinaryField': 'longblob',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime(6)',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'bigint',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'char(15)',
'GenericIPAddressField': 'char(39)',
'JSONField': 'json',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint UNSIGNED',
'PositiveIntegerField': 'integer UNSIGNED',
'PositiveSmallIntegerField': 'smallint UNSIGNED',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallint AUTO_INCREMENT',
'SmallIntegerField': 'smallint',
'TextField': 'longtext',
'TimeField': 'time(6)',
'UUIDField': 'char(32)',
}
# For these data types:
@@ -139,30 +138,23 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# - all versions of MySQL and MariaDB don't support full width database
# indexes
_limited_data_types = (
"tinyblob",
"blob",
"mediumblob",
"longblob",
"tinytext",
"text",
"mediumtext",
"longtext",
"json",
'tinyblob', 'blob', 'mediumblob', 'longblob', 'tinytext', 'text',
'mediumtext', 'longtext', 'json',
)
operators = {
"exact": "= %s",
"iexact": "LIKE %s",
"contains": "LIKE BINARY %s",
"icontains": "LIKE %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE BINARY %s",
"endswith": "LIKE BINARY %s",
"istartswith": "LIKE %s",
"iendswith": "LIKE %s",
'exact': '= %s',
'iexact': 'LIKE %s',
'contains': 'LIKE BINARY %s',
'icontains': 'LIKE %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE BINARY %s',
'endswith': 'LIKE BINARY %s',
'istartswith': 'LIKE %s',
'iendswith': 'LIKE %s',
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -175,19 +167,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
"contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
"icontains": "LIKE CONCAT('%%', {}, '%%')",
"startswith": "LIKE BINARY CONCAT({}, '%%')",
"istartswith": "LIKE CONCAT({}, '%%')",
"endswith": "LIKE BINARY CONCAT('%%', {})",
"iendswith": "LIKE CONCAT('%%', {})",
'contains': "LIKE BINARY CONCAT('%%', {}, '%%')",
'icontains': "LIKE CONCAT('%%', {}, '%%')",
'startswith': "LIKE BINARY CONCAT({}, '%%')",
'istartswith': "LIKE CONCAT({}, '%%')",
'endswith': "LIKE BINARY CONCAT('%%', {})",
'iendswith': "LIKE CONCAT('%%', {})",
}
isolation_levels = {
"read uncommitted",
"read committed",
"repeatable read",
"serializable",
'read uncommitted',
'read committed',
'repeatable read',
'serializable',
}
Database = Database
@@ -202,39 +194,37 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
kwargs = {
"conv": django_conversions,
"charset": "utf8",
'conv': django_conversions,
'charset': 'utf8',
}
settings_dict = self.settings_dict
if settings_dict["USER"]:
kwargs["user"] = settings_dict["USER"]
if settings_dict["NAME"]:
kwargs["database"] = settings_dict["NAME"]
if settings_dict["PASSWORD"]:
kwargs["password"] = settings_dict["PASSWORD"]
if settings_dict["HOST"].startswith("/"):
kwargs["unix_socket"] = settings_dict["HOST"]
elif settings_dict["HOST"]:
kwargs["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
kwargs["port"] = int(settings_dict["PORT"])
if settings_dict['USER']:
kwargs['user'] = settings_dict['USER']
if settings_dict['NAME']:
kwargs['database'] = settings_dict['NAME']
if settings_dict['PASSWORD']:
kwargs['password'] = settings_dict['PASSWORD']
if settings_dict['HOST'].startswith('/'):
kwargs['unix_socket'] = settings_dict['HOST']
elif settings_dict['HOST']:
kwargs['host'] = settings_dict['HOST']
if settings_dict['PORT']:
kwargs['port'] = int(settings_dict['PORT'])
# We need the number of potentially affected rows after an
# "UPDATE", not the number of changed rows.
kwargs["client_flag"] = CLIENT.FOUND_ROWS
kwargs['client_flag'] = CLIENT.FOUND_ROWS
# Validate the transaction isolation level, if specified.
options = settings_dict["OPTIONS"].copy()
isolation_level = options.pop("isolation_level", "read committed")
options = settings_dict['OPTIONS'].copy()
isolation_level = options.pop('isolation_level', 'read committed')
if isolation_level:
isolation_level = isolation_level.lower()
if isolation_level not in self.isolation_levels:
raise ImproperlyConfigured(
"Invalid transaction isolation level '%s' specified.\n"
"Use one of %s, or None."
% (
"Use one of %s, or None." % (
isolation_level,
", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
)
)
', '.join("'%s'" % s for s in sorted(self.isolation_levels))
))
self.isolation_level = isolation_level
kwargs.update(options)
return kwargs
@@ -257,17 +247,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# a recently inserted row will return when the field is tested
# for NULL. Disabling this brings this aspect of MySQL in line
# with SQL standards.
assignments.append("SET SQL_AUTO_IS_NULL = 0")
assignments.append('SET SQL_AUTO_IS_NULL = 0')
if self.isolation_level:
assignments.append(
"SET SESSION TRANSACTION ISOLATION LEVEL %s"
% self.isolation_level.upper()
)
assignments.append('SET SESSION TRANSACTION ISOLATION LEVEL %s' % self.isolation_level.upper())
if assignments:
with self.cursor() as cursor:
cursor.execute("; ".join(assignments))
cursor.execute('; '.join(assignments))
@async_unsafe
def create_cursor(self, name=None):
@@ -291,7 +278,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
need to be re-enabled.
"""
with self.cursor() as cursor:
cursor.execute("SET foreign_key_checks=0")
cursor.execute('SET foreign_key_checks=0')
return True
def enable_constraint_checking(self):
@@ -303,7 +290,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.needs_rollback, needs_rollback = False, self.needs_rollback
try:
with self.cursor() as cursor:
cursor.execute("SET foreign_key_checks=1")
cursor.execute('SET foreign_key_checks=1')
finally:
self.needs_rollback = needs_rollback
@@ -319,48 +306,31 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for (
column_name,
referenced_table_name,
referenced_column_name,
) in key_columns:
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
primary_key_column_name,
column_name,
table_name,
referenced_table_name,
column_name,
referenced_column_name,
column_name,
referenced_column_name,
""" % (
primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
"The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not "
"have a corresponding value in %s.%s."
% (
table_name,
bad_row[0],
table_name,
column_name,
bad_row[1],
referenced_table_name,
referenced_column_name,
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
@@ -374,20 +344,20 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property
def display_name(self):
return "MariaDB" if self.mysql_is_mariadb else "MySQL"
return 'MariaDB' if self.mysql_is_mariadb else 'MySQL'
@cached_property
def data_type_check_constraints(self):
if self.features.supports_column_check_constraints:
check_constraints = {
"PositiveBigIntegerField": "`%(column)s` >= 0",
"PositiveIntegerField": "`%(column)s` >= 0",
"PositiveSmallIntegerField": "`%(column)s` >= 0",
'PositiveBigIntegerField': '`%(column)s` >= 0',
'PositiveIntegerField': '`%(column)s` >= 0',
'PositiveSmallIntegerField': '`%(column)s` >= 0',
}
if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
# MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
# a check constraint.
check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
check_constraints['JSONField'] = 'JSON_VALID(`%(column)s`)'
return check_constraints
return {}
@@ -397,45 +367,40 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Select some server variables and test if the time zone
# definitions are installed. CONVERT_TZ returns NULL if 'UTC'
# timezone isn't loaded into the mysql.time_zone table.
cursor.execute(
"""
cursor.execute("""
SELECT VERSION(),
@@sql_mode,
@@default_storage_engine,
@@sql_auto_is_null,
@@lower_case_table_names,
CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
"""
)
""")
row = cursor.fetchone()
return {
"version": row[0],
"sql_mode": row[1],
"default_storage_engine": row[2],
"sql_auto_is_null": bool(row[3]),
"lower_case_table_names": bool(row[4]),
"has_zoneinfo_database": bool(row[5]),
'version': row[0],
'sql_mode': row[1],
'default_storage_engine': row[2],
'sql_auto_is_null': bool(row[3]),
'lower_case_table_names': bool(row[4]),
'has_zoneinfo_database': bool(row[5]),
}
@cached_property
def mysql_server_info(self):
return self.mysql_server_data["version"]
return self.mysql_server_data['version']
@cached_property
def mysql_version(self):
match = server_version_re.match(self.mysql_server_info)
if not match:
raise Exception(
"Unable to determine MySQL version from version string %r"
% self.mysql_server_info
)
raise Exception('Unable to determine MySQL version from version string %r' % self.mysql_server_info)
return tuple(int(x) for x in match.groups())
@cached_property
def mysql_is_mariadb(self):
return "mariadb" in self.mysql_server_info.lower()
return 'mariadb' in self.mysql_server_info.lower()
@cached_property
def sql_mode(self):
sql_mode = self.mysql_server_data["sql_mode"]
return set(sql_mode.split(",") if sql_mode else ())
sql_mode = self.mysql_server_data['sql_mode']
return set(sql_mode.split(',') if sql_mode else ())
@@ -2,28 +2,28 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "mysql"
executable_name = 'mysql'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
env = None
database = settings_dict["OPTIONS"].get(
"database",
settings_dict["OPTIONS"].get("db", settings_dict["NAME"]),
database = settings_dict['OPTIONS'].get(
'database',
settings_dict['OPTIONS'].get('db', settings_dict['NAME']),
)
user = settings_dict["OPTIONS"].get("user", settings_dict["USER"])
password = settings_dict["OPTIONS"].get(
"password",
settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]),
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
password = settings_dict['OPTIONS'].get(
'password',
settings_dict['OPTIONS'].get('passwd', settings_dict['PASSWORD'])
)
host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"])
port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"])
server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca")
client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert")
client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key")
defaults_file = settings_dict["OPTIONS"].get("read_default_file")
charset = settings_dict["OPTIONS"].get("charset")
host = settings_dict['OPTIONS'].get('host', settings_dict['HOST'])
port = settings_dict['OPTIONS'].get('port', settings_dict['PORT'])
server_ca = settings_dict['OPTIONS'].get('ssl', {}).get('ca')
client_cert = settings_dict['OPTIONS'].get('ssl', {}).get('cert')
client_key = settings_dict['OPTIONS'].get('ssl', {}).get('key')
defaults_file = settings_dict['OPTIONS'].get('read_default_file')
charset = settings_dict['OPTIONS'].get('charset')
# Seems to be no good way to set sql_mode with CLI.
if defaults_file:
@@ -38,9 +38,9 @@ class DatabaseClient(BaseDatabaseClient):
# prevents password exposure if the subprocess.run(check=True) call
# raises a CalledProcessError since the string representation of
# the latter includes all of the provided `args`.
env = {"MYSQL_PWD": password}
env = {'MYSQL_PWD': password}
if host:
if "/" in host:
if '/' in host:
args += ["--socket=%s" % host]
else:
args += ["--host=%s" % host]
@@ -53,7 +53,7 @@ class DatabaseClient(BaseDatabaseClient):
if client_key:
args += ["--ssl-key=%s" % client_key]
if charset:
args += ["--default-character-set=%s" % charset]
args += ['--default-character-set=%s' % charset]
if database:
args += [database]
args.extend(parameters)
@@ -8,14 +8,7 @@ class SQLCompiler(compiler.SQLCompiler):
qn = compiler.quote_name_unless_alias
qn2 = self.connection.ops.quote_name
sql, params = self.as_sql()
return (
"(%s) IN (%s)"
% (
", ".join("%s.%s" % (qn(alias), qn2(column)) for column in columns),
sql,
),
params,
)
return '(%s) IN (%s)' % (', '.join('%s.%s' % (qn(alias), qn2(column)) for column in columns), sql), params
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
@@ -34,15 +27,16 @@ class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
# since it doesn't allow for GROUP BY and HAVING clauses.
return super().as_sql()
result = [
"DELETE %s FROM"
% self.quote_name_unless_alias(self.query.get_initial_alias())
'DELETE %s FROM' % self.quote_name_unless_alias(
self.query.get_initial_alias()
)
]
from_sql, from_params = self.get_from_clause()
result.extend(from_sql)
where_sql, where_params = self.compile(where)
if where_sql:
result.append("WHERE %s" % where_sql)
return " ".join(result), tuple(from_params) + tuple(where_params)
result.append('WHERE %s' % where_sql)
return ' '.join(result), tuple(from_params) + tuple(where_params)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
@@ -56,15 +50,15 @@ class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
try:
for resolved, (sql, params, _) in self.get_order_by():
if (
isinstance(resolved.expression, Col)
and resolved.expression.alias != db_table
isinstance(resolved.expression, Col) and
resolved.expression.alias != db_table
):
# Ignore ordering if it contains joined fields, because
# they cannot be used in the ORDER BY clause.
raise FieldError
order_by_sql.append(sql)
order_by_params.extend(params)
update_query += " ORDER BY " + ", ".join(order_by_sql)
update_query += ' ORDER BY ' + ', '.join(order_by_sql)
update_params += tuple(order_by_params)
except FieldError:
# Ignore ordering if it contains annotations, because they're
@@ -8,14 +8,15 @@ from .client import DatabaseClient
class DatabaseCreation(BaseDatabaseCreation):
def sql_table_creation_suffix(self):
suffix = []
test_settings = self.connection.settings_dict["TEST"]
if test_settings["CHARSET"]:
suffix.append("CHARACTER SET %s" % test_settings["CHARSET"])
if test_settings["COLLATION"]:
suffix.append("COLLATE %s" % test_settings["COLLATION"])
return " ".join(suffix)
test_settings = self.connection.settings_dict['TEST']
if test_settings['CHARSET']:
suffix.append('CHARACTER SET %s' % test_settings['CHARSET'])
if test_settings['COLLATION']:
suffix.append('COLLATE %s' % test_settings['COLLATION'])
return ' '.join(suffix)
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
@@ -23,17 +24,17 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception as e:
if len(e.args) < 1 or e.args[0] != 1007:
# All errors except "database exists" (1007) cancel tests.
self.log("Got an error creating the test database: %s" % e)
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
else:
raise
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
test_db_params = {
"dbname": self.connection.ops.quote_name(target_database_name),
"suffix": self.sql_table_creation_suffix(),
'dbname': self.connection.ops.quote_name(target_database_name),
'suffix': self.sql_table_creation_suffix(),
}
with self._nodb_cursor() as cursor:
try:
@@ -44,44 +45,24 @@ class DatabaseCreation(BaseDatabaseCreation):
return
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log("Got an error recreating the test database: %s" % e)
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
self._clone_db(source_database_name, target_database_name)
def _clone_db(self, source_database_name, target_database_name):
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(
self.connection.settings_dict, []
)
dump_cmd = [
"mysqldump",
*cmd_args[1:-1],
"--routines",
"--events",
source_database_name,
]
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, [])
dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name]
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
load_cmd = cmd_args
load_cmd[-1] = target_database_name
with subprocess.Popen(
dump_cmd, stdout=subprocess.PIPE, env=dump_env
) as dump_proc:
with subprocess.Popen(
load_cmd,
stdin=dump_proc.stdout,
stdout=subprocess.DEVNULL,
env=load_env,
):
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc:
with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env):
# Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close()
@@ -47,119 +47,66 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_order_by_nulls_modifier = False
order_by_nulls_first = True
@cached_property
def test_collations(self):
charset = "utf8"
if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
10,
6,
):
# utf8 is an alias for utf8mb3 in MariaDB 10.6+.
charset = "utf8mb3"
return {
"ci": f"{charset}_general_ci",
"non_default": f"{charset}_esperanto_ci",
"swedish_ci": f"{charset}_swedish_ci",
}
test_now_utc_template = "UTC_TIMESTAMP"
test_collations = {
'ci': 'utf8_general_ci',
'non_default': 'utf8_esperanto_ci',
'swedish_ci': 'utf8_swedish_ci',
}
@cached_property
def django_test_skips(self):
skips = {
"This doesn't work on MySQL.": {
"db_functions.comparison.test_greatest.GreatestTests."
"test_coalesce_workaround",
"db_functions.comparison.test_least.LeastTests."
"test_coalesce_workaround",
'db_functions.comparison.test_greatest.GreatestTests.test_coalesce_workaround',
'db_functions.comparison.test_least.LeastTests.test_coalesce_workaround',
},
"Running on MySQL requires utf8mb4 encoding (#18392).": {
"model_fields.test_textfield.TextFieldTests.test_emoji",
"model_fields.test_charfield.TestCharField.test_emoji",
'Running on MySQL requires utf8mb4 encoding (#18392).': {
'model_fields.test_textfield.TextFieldTests.test_emoji',
'model_fields.test_charfield.TestCharField.test_emoji',
},
"MySQL doesn't support functional indexes on a function that "
"returns JSON": {
"schema.tests.SchemaTests.test_func_index_json_key_transform",
},
"MySQL supports multiplying and dividing DurationFields by a "
"scalar value but it's not implemented (#25287).": {
"expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide",
'schema.tests.SchemaTests.test_func_index_json_key_transform',
},
}
if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode:
skips.update(
{
"GROUP BY optimization does not work properly when "
"ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.": {
"aggregation.tests.AggregateTestCase."
"test_aggregation_subquery_annotation_multivalued",
"annotations.tests.NonAggregateAnnotationTestCase."
"test_annotation_aggregate_with_m2o",
},
}
)
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (
8,
if 'ONLY_FULL_GROUP_BY' in self.connection.sql_mode:
skips.update({
'GROUP BY optimization does not work properly when '
'ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #31331.': {
'aggregation.tests.AggregateTestCase.test_aggregation_subquery_annotation_multivalued',
'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
},
})
if (
self.connection.mysql_is_mariadb and
(10, 4, 3) < self.connection.mysql_version < (10, 5, 2)
):
skips.update(
{
"Casting to datetime/time is not supported by MySQL < 8.0. "
"(#30224)": {
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_time_from_python",
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_datetime_from_python",
},
"MySQL < 8.0 returns string type instead of datetime/time. "
"(#30224)": {
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_time_from_database",
"aggregation.tests.AggregateTestCase."
"test_aggregation_default_using_datetime_from_database",
},
}
)
if self.connection.mysql_is_mariadb and (
10,
4,
3,
) < self.connection.mysql_version < (10, 5, 2):
skips.update(
{
"https://jira.mariadb.org/browse/MDEV-19598": {
"schema.tests.SchemaTests."
"test_alter_not_unique_field_to_primary_key",
},
}
)
if self.connection.mysql_is_mariadb and (
10,
4,
12,
) < self.connection.mysql_version < (10, 5):
skips.update(
{
"https://jira.mariadb.org/browse/MDEV-22775": {
"schema.tests.SchemaTests."
"test_alter_pk_with_self_referential_field",
},
}
)
skips.update({
'https://jira.mariadb.org/browse/MDEV-19598': {
'schema.tests.SchemaTests.test_alter_not_unique_field_to_primary_key',
},
})
if (
self.connection.mysql_is_mariadb and
(10, 4, 12) < self.connection.mysql_version < (10, 5)
):
skips.update({
'https://jira.mariadb.org/browse/MDEV-22775': {
'schema.tests.SchemaTests.test_alter_pk_with_self_referential_field',
},
})
if not self.supports_explain_analyze:
skips.update(
{
"MariaDB and MySQL >= 8.0.18 specific.": {
"queries.test_explain.ExplainTests.test_mysql_analyze",
},
}
)
skips.update({
'MariaDB and MySQL >= 8.0.18 specific.': {
'queries.test_explain.ExplainTests.test_mysql_analyze',
},
})
return skips
@cached_property
def _mysql_storage_engine(self):
"Internal method used in Django tests. Don't rely on this from your code"
return self.connection.mysql_server_data["default_storage_engine"]
return self.connection.mysql_server_data['default_storage_engine']
@cached_property
def allows_auto_pk_0(self):
@@ -167,50 +114,40 @@ class DatabaseFeatures(BaseDatabaseFeatures):
Autoincrement primary key can be set to 0 if it doesn't generate new
autoincrement values.
"""
return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode
return 'NO_AUTO_VALUE_ON_ZERO' in self.connection.sql_mode
@cached_property
def update_can_self_select(self):
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
10,
3,
2,
)
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 3, 2)
@cached_property
def can_introspect_foreign_keys(self):
"Confirm support for introspected foreign keys"
return self._mysql_storage_engine != "MyISAM"
return self._mysql_storage_engine != 'MyISAM'
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"BinaryField": "TextField",
"BooleanField": "IntegerField",
"DurationField": "BigIntegerField",
"GenericIPAddressField": "CharField",
'BinaryField': 'TextField',
'BooleanField': 'IntegerField',
'DurationField': 'BigIntegerField',
'GenericIPAddressField': 'CharField',
}
@cached_property
def can_return_columns_from_insert(self):
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
10,
5,
0,
)
return self.connection.mysql_is_mariadb and self.connection.mysql_version >= (10, 5, 0)
can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert")
)
can_return_rows_from_bulk_insert = property(operator.attrgetter('can_return_columns_from_insert'))
@cached_property
def has_zoneinfo_database(self):
return self.connection.mysql_server_data["has_zoneinfo_database"]
return self.connection.mysql_server_data['has_zoneinfo_database']
@cached_property
def is_sql_auto_is_null_enabled(self):
return self.connection.mysql_server_data["sql_auto_is_null"]
return self.connection.mysql_server_data['sql_auto_is_null']
@cached_property
def supports_over_clause(self):
@@ -218,9 +155,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return True
return self.connection.mysql_version >= (8, 0, 2)
supports_frame_range_fixed_distance = property(
operator.attrgetter("supports_over_clause")
)
supports_frame_range_fixed_distance = property(operator.attrgetter('supports_over_clause'))
@cached_property
def supports_column_check_constraints(self):
@@ -228,26 +163,18 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return self.connection.mysql_version >= (10, 2, 1)
return self.connection.mysql_version >= (8, 0, 16)
supports_table_check_constraints = property(
operator.attrgetter("supports_column_check_constraints")
)
supports_table_check_constraints = property(operator.attrgetter('supports_column_check_constraints'))
@cached_property
def can_introspect_check_constraints(self):
if self.connection.mysql_is_mariadb:
version = self.connection.mysql_version
return (version >= (10, 2, 22) and version < (10, 3)) or version >= (
10,
3,
10,
)
return (version >= (10, 2, 22) and version < (10, 3)) or version >= (10, 3, 10)
return self.connection.mysql_version >= (8, 0, 16)
@cached_property
def has_select_for_update_skip_locked(self):
if self.connection.mysql_is_mariadb:
return self.connection.mysql_version >= (10, 6)
return self.connection.mysql_version >= (8, 0, 1)
return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1)
@cached_property
def has_select_for_update_nowait(self):
@@ -257,30 +184,19 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def has_select_for_update_of(self):
return (
not self.connection.mysql_is_mariadb
and self.connection.mysql_version >= (8, 0, 1)
)
return not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 1)
@cached_property
def supports_explain_analyze(self):
return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (
8,
0,
18,
)
return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (8, 0, 18)
@cached_property
def supported_explain_formats(self):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other
# backends.
formats = {"JSON", "TEXT", "TRADITIONAL"}
if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
8,
0,
16,
):
formats.add("TREE")
formats = {'JSON', 'TEXT', 'TRADITIONAL'}
if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (8, 0, 16):
formats.add('TREE')
return formats
@cached_property
@@ -288,11 +204,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"""
All storage engines except MyISAM support transactions.
"""
return self._mysql_storage_engine != "MyISAM"
return self._mysql_storage_engine != 'MyISAM'
@cached_property
def ignores_table_name_case(self):
return self.connection.mysql_server_data["lower_case_table_names"]
return self.connection.mysql_server_data['lower_case_table_names']
@cached_property
def supports_default_in_lead_lag(self):
@@ -314,13 +230,13 @@ class DatabaseFeatures(BaseDatabaseFeatures):
@cached_property
def supports_index_column_ordering(self):
return (
not self.connection.mysql_is_mariadb
and self.connection.mysql_version >= (8, 0, 1)
not self.connection.mysql_is_mariadb and
self.connection.mysql_version >= (8, 0, 1)
)
@cached_property
def supports_expression_indexes(self):
return (
not self.connection.mysql_is_mariadb
and self.connection.mysql_version >= (8, 0, 13)
not self.connection.mysql_is_mariadb and
self.connection.mysql_version >= (8, 0, 13)
)
@@ -3,76 +3,72 @@ from collections import namedtuple
import sqlparse
from MySQLdb.constants import FIELD_TYPE
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models import Index
from django.utils.datastructures import OrderedSet
FieldInfo = namedtuple(
"FieldInfo", BaseFieldInfo._fields + ("extra", "is_unsigned", "has_json_constraint")
)
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('extra', 'is_unsigned', 'has_json_constraint'))
InfoLine = namedtuple(
"InfoLine",
"col_name data_type max_len num_prec num_scale extra column_default "
"collation is_unsigned",
'InfoLine',
'col_name data_type max_len num_prec num_scale extra column_default '
'collation is_unsigned'
)
class DatabaseIntrospection(BaseDatabaseIntrospection):
data_types_reverse = {
FIELD_TYPE.BLOB: "TextField",
FIELD_TYPE.CHAR: "CharField",
FIELD_TYPE.DECIMAL: "DecimalField",
FIELD_TYPE.NEWDECIMAL: "DecimalField",
FIELD_TYPE.DATE: "DateField",
FIELD_TYPE.DATETIME: "DateTimeField",
FIELD_TYPE.DOUBLE: "FloatField",
FIELD_TYPE.FLOAT: "FloatField",
FIELD_TYPE.INT24: "IntegerField",
FIELD_TYPE.JSON: "JSONField",
FIELD_TYPE.LONG: "IntegerField",
FIELD_TYPE.LONGLONG: "BigIntegerField",
FIELD_TYPE.SHORT: "SmallIntegerField",
FIELD_TYPE.STRING: "CharField",
FIELD_TYPE.TIME: "TimeField",
FIELD_TYPE.TIMESTAMP: "DateTimeField",
FIELD_TYPE.TINY: "IntegerField",
FIELD_TYPE.TINY_BLOB: "TextField",
FIELD_TYPE.MEDIUM_BLOB: "TextField",
FIELD_TYPE.LONG_BLOB: "TextField",
FIELD_TYPE.VAR_STRING: "CharField",
FIELD_TYPE.BLOB: 'TextField',
FIELD_TYPE.CHAR: 'CharField',
FIELD_TYPE.DECIMAL: 'DecimalField',
FIELD_TYPE.NEWDECIMAL: 'DecimalField',
FIELD_TYPE.DATE: 'DateField',
FIELD_TYPE.DATETIME: 'DateTimeField',
FIELD_TYPE.DOUBLE: 'FloatField',
FIELD_TYPE.FLOAT: 'FloatField',
FIELD_TYPE.INT24: 'IntegerField',
FIELD_TYPE.JSON: 'JSONField',
FIELD_TYPE.LONG: 'IntegerField',
FIELD_TYPE.LONGLONG: 'BigIntegerField',
FIELD_TYPE.SHORT: 'SmallIntegerField',
FIELD_TYPE.STRING: 'CharField',
FIELD_TYPE.TIME: 'TimeField',
FIELD_TYPE.TIMESTAMP: 'DateTimeField',
FIELD_TYPE.TINY: 'IntegerField',
FIELD_TYPE.TINY_BLOB: 'TextField',
FIELD_TYPE.MEDIUM_BLOB: 'TextField',
FIELD_TYPE.LONG_BLOB: 'TextField',
FIELD_TYPE.VAR_STRING: 'CharField',
}
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if "auto_increment" in description.extra:
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
if 'auto_increment' in description.extra:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
if description.is_unsigned:
if field_type == "BigIntegerField":
return "PositiveBigIntegerField"
elif field_type == "IntegerField":
return "PositiveIntegerField"
elif field_type == "SmallIntegerField":
return "PositiveSmallIntegerField"
if field_type == 'BigIntegerField':
return 'PositiveBigIntegerField'
elif field_type == 'IntegerField':
return 'PositiveIntegerField'
elif field_type == 'SmallIntegerField':
return 'PositiveSmallIntegerField'
# JSON data type is an alias for LONGTEXT in MariaDB, use check
# constraints clauses to introspect JSONField.
if description.has_json_constraint:
return "JSONField"
return 'JSONField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute("SHOW FULL TABLES")
return [
TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]))
for row in cursor.fetchall()
]
return [TableInfo(row[0], {'BASE TABLE': 't', 'VIEW': 'v'}.get(row[1]))
for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
@@ -80,44 +76,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
interface."
"""
json_constraints = {}
if (
self.connection.mysql_is_mariadb
and self.connection.features.can_introspect_json_field
):
if self.connection.mysql_is_mariadb and self.connection.features.can_introspect_json_field:
# JSON data type is an alias for LONGTEXT in MariaDB, select
# JSON_VALID() constraints to introspect JSONField.
cursor.execute(
"""
cursor.execute("""
SELECT c.constraint_name AS column_name
FROM information_schema.check_constraints AS c
WHERE
c.table_name = %s AND
LOWER(c.check_clause) =
'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
LOWER(c.check_clause) = 'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
c.constraint_schema = DATABASE()
""",
[table_name],
)
""", [table_name])
json_constraints = {row[0] for row in cursor.fetchall()}
# A default collation for the given table.
cursor.execute(
"""
cursor.execute("""
SELECT table_collation
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = %s
""",
[table_name],
)
""", [table_name])
row = cursor.fetchone()
default_column_collation = row[0] if row else ""
default_column_collation = row[0] if row else ''
# information_schema database gives more accurate results for some figures:
# - varchar length returned by cursor.description is an internal length,
# not visible length (#5725)
# - precision and scale (for decimal fields) (#5014)
# - auto_increment is not available in cursor.description
cursor.execute(
"""
cursor.execute("""
SELECT
column_name, data_type, character_maximum_length,
numeric_precision, numeric_scale, extra, column_default,
@@ -131,14 +116,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
END AS is_unsigned
FROM information_schema.columns
WHERE table_name = %s AND table_schema = DATABASE()
""",
[default_column_collation, table_name],
)
""", [default_column_collation, table_name])
field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
cursor.execute(
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
)
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
def to_int(i):
return int(i) if i is not None else i
@@ -146,27 +127,25 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
fields = []
for line in cursor.description:
info = field_info[line[0]]
fields.append(
FieldInfo(
*line[:3],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.collation,
info.extra,
info.is_unsigned,
line[0] in json_constraints,
)
)
fields.append(FieldInfo(
*line[:3],
to_int(info.max_len) or line[3],
to_int(info.num_prec) or line[4],
to_int(info.num_scale) or line[5],
line[6],
info.column_default,
info.collation,
info.extra,
info.is_unsigned,
line[0] in json_constraints,
))
return fields
def get_sequences(self, cursor, table_name, table_fields=()):
for field_info in self.get_table_description(cursor, table_name):
if "auto_increment" in field_info.extra:
if 'auto_increment' in field_info.extra:
# MySQL allows only one auto-increment column per table.
return [{"table": table_name, "column": field_info.name}]
return [{'table': table_name, 'column': field_info.name}]
return []
def get_relations(self, cursor, table_name):
@@ -186,17 +165,13 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for all key columns in the given table.
"""
key_columns = []
cursor.execute(
"""
cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage
WHERE table_name = %s
AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL
""",
[table_name],
)
AND referenced_column_name IS NOT NULL""", [table_name])
key_columns.extend(cursor.fetchall())
return key_columns
@@ -206,15 +181,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
storage engine if the table doesn't exist.
"""
cursor.execute(
"""
SELECT engine
FROM information_schema.tables
WHERE
table_name = %s AND
table_schema = DATABASE()
""",
[table_name],
)
"SELECT engine "
"FROM information_schema.tables "
"WHERE table_name = %s", [table_name])
result = cursor.fetchone()
if not result:
return self.connection.features._mysql_storage_engine
@@ -226,9 +195,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
tokens = (token for token in statement.flatten() if not token.is_whitespace)
for token in tokens:
if (
token.ttype == sqlparse.tokens.Name
and self.connection.ops.quote_name(token.value) == token.value
and token.value[1:-1] in columns
token.ttype == sqlparse.tokens.Name and
self.connection.ops.quote_name(token.value) == token.value and
token.value[1:-1] in columns
):
check_columns.add(token.value[1:-1])
return check_columns
@@ -242,39 +211,46 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Get the actual constraint names and columns
name_query = """
SELECT kc.`constraint_name`, kc.`column_name`,
kc.`referenced_table_name`, kc.`referenced_column_name`,
c.`constraint_type`
FROM
information_schema.key_column_usage AS kc,
information_schema.table_constraints AS c
kc.`referenced_table_name`, kc.`referenced_column_name`
FROM information_schema.key_column_usage AS kc
WHERE
kc.table_schema = DATABASE() AND
c.table_schema = kc.table_schema AND
c.constraint_name = kc.constraint_name AND
c.constraint_type != 'CHECK' AND
kc.table_name = %s
ORDER BY kc.`ordinal_position`
"""
cursor.execute(name_query, [table_name])
for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
for constraint, column, ref_table, ref_column in cursor.fetchall():
if constraint not in constraints:
constraints[constraint] = {
"columns": OrderedSet(),
"primary_key": kind == "PRIMARY KEY",
"unique": kind in {"PRIMARY KEY", "UNIQUE"},
"index": False,
"check": False,
"foreign_key": (ref_table, ref_column) if ref_column else None,
'columns': OrderedSet(),
'primary_key': False,
'unique': False,
'index': False,
'check': False,
'foreign_key': (ref_table, ref_column) if ref_column else None,
}
if self.connection.features.supports_index_column_ordering:
constraints[constraint]["orders"] = []
constraints[constraint]["columns"].add(column)
constraints[constraint]['orders'] = []
constraints[constraint]['columns'].add(column)
# Now get the constraint types
type_query = """
SELECT c.constraint_name, c.constraint_type
FROM information_schema.table_constraints AS c
WHERE
c.table_schema = DATABASE() AND
c.table_name = %s
"""
cursor.execute(type_query, [table_name])
for constraint, kind in cursor.fetchall():
if kind.lower() == "primary key":
constraints[constraint]['primary_key'] = True
constraints[constraint]['unique'] = True
elif kind.lower() == "unique":
constraints[constraint]['unique'] = True
# Add check constraints.
if self.connection.features.can_introspect_check_constraints:
unnamed_constraints_index = 0
columns = {
info.name for info in self.get_table_description(cursor, table_name)
}
columns = {info.name for info in self.get_table_description(cursor, table_name)}
if self.connection.mysql_is_mariadb:
type_query = """
SELECT c.constraint_name, c.check_clause
@@ -298,48 +274,42 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"""
cursor.execute(type_query, [table_name])
for constraint, check_clause in cursor.fetchall():
constraint_columns = self._parse_constraint_columns(
check_clause, columns
)
constraint_columns = self._parse_constraint_columns(check_clause, columns)
# Ensure uniqueness of unnamed constraints. Unnamed unique
# and check columns constraints have the same name as
# a column.
if set(constraint_columns) == {constraint}:
unnamed_constraints_index += 1
constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
constraint = '__unnamed_constraint_%s__' % unnamed_constraints_index
constraints[constraint] = {
"columns": constraint_columns,
"primary_key": False,
"unique": False,
"index": False,
"check": True,
"foreign_key": None,
'columns': constraint_columns,
'primary_key': False,
'unique': False,
'index': False,
'check': True,
'foreign_key': None,
}
# Now add in the indexes
cursor.execute(
"SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
)
cursor.execute("SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name))
for table, non_unique, index, colseq, column, order, type_ in [
x[:6] + (x[10],) for x in cursor.fetchall()
]:
if index not in constraints:
constraints[index] = {
"columns": OrderedSet(),
"primary_key": False,
"unique": not non_unique,
"check": False,
"foreign_key": None,
'columns': OrderedSet(),
'primary_key': False,
'unique': False,
'check': False,
'foreign_key': None,
}
if self.connection.features.supports_index_column_ordering:
constraints[index]["orders"] = []
constraints[index]["index"] = True
constraints[index]["type"] = (
Index.suffix if type_ == "BTREE" else type_.lower()
)
constraints[index]["columns"].add(column)
constraints[index]['orders'] = []
constraints[index]['index'] = True
constraints[index]['type'] = Index.suffix if type_ == 'BTREE' else type_.lower()
constraints[index]['columns'].add(column)
if self.connection.features.supports_index_column_ordering:
constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
constraints[index]['orders'].append('DESC' if order == 'D' else 'ASC')
# Convert the sorted sets to lists
for constraint in constraints.values():
constraint["columns"] = list(constraint["columns"])
constraint['columns'] = list(constraint['columns'])
return constraints
@@ -2,7 +2,6 @@ import uuid
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.utils import timezone
from django.utils.encoding import force_str
@@ -13,42 +12,42 @@ class DatabaseOperations(BaseDatabaseOperations):
# MySQL stores positive fields as UNSIGNED ints.
integer_field_ranges = {
**BaseDatabaseOperations.integer_field_ranges,
"PositiveSmallIntegerField": (0, 65535),
"PositiveIntegerField": (0, 4294967295),
"PositiveBigIntegerField": (0, 18446744073709551615),
'PositiveSmallIntegerField': (0, 65535),
'PositiveIntegerField': (0, 4294967295),
'PositiveBigIntegerField': (0, 18446744073709551615),
}
cast_data_types = {
"AutoField": "signed integer",
"BigAutoField": "signed integer",
"SmallAutoField": "signed integer",
"CharField": "char(%(max_length)s)",
"DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
"TextField": "char",
"IntegerField": "signed integer",
"BigIntegerField": "signed integer",
"SmallIntegerField": "signed integer",
"PositiveBigIntegerField": "unsigned integer",
"PositiveIntegerField": "unsigned integer",
"PositiveSmallIntegerField": "unsigned integer",
"DurationField": "signed integer",
'AutoField': 'signed integer',
'BigAutoField': 'signed integer',
'SmallAutoField': 'signed integer',
'CharField': 'char(%(max_length)s)',
'DecimalField': 'decimal(%(max_digits)s, %(decimal_places)s)',
'TextField': 'char',
'IntegerField': 'signed integer',
'BigIntegerField': 'signed integer',
'SmallIntegerField': 'signed integer',
'PositiveBigIntegerField': 'unsigned integer',
'PositiveIntegerField': 'unsigned integer',
'PositiveSmallIntegerField': 'unsigned integer',
'DurationField': 'signed integer',
}
cast_char_field_without_max_length = "char"
explain_prefix = "EXPLAIN"
cast_char_field_without_max_length = 'char'
explain_prefix = 'EXPLAIN'
def date_extract_sql(self, lookup_type, field_name):
# https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
if lookup_type == "week_day":
if lookup_type == 'week_day':
# DAYOFWEEK() returns an integer, 1-7, Sunday=1.
return "DAYOFWEEK(%s)" % field_name
elif lookup_type == "iso_week_day":
elif lookup_type == 'iso_week_day':
# WEEKDAY() returns an integer, 0-6, Monday=0.
return "WEEKDAY(%s) + 1" % field_name
elif lookup_type == "week":
elif lookup_type == 'week':
# Override the value of default_week_format for consistency with
# other database backends.
# Mode 3: Monday, 1-53, with 4 or more days this year.
return "WEEK(%s, 3)" % field_name
elif lookup_type == "iso_year":
elif lookup_type == 'iso_year':
# Get the year part from the YEARWEEK function, which returns a
# number as year * 100 + week.
return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
@@ -59,26 +58,29 @@ class DatabaseOperations(BaseDatabaseOperations):
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
"year": "%%Y-01-01",
"month": "%%Y-%%m-01",
'year': '%%Y-01-01',
'month': '%%Y-%%m-01',
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
elif lookup_type == "quarter":
return (
"MAKEDATE(YEAR(%s), 1) + "
"INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER"
% (field_name, field_name)
elif lookup_type == 'quarter':
return "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" % (
field_name, field_name
)
elif lookup_type == 'week':
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (
field_name, field_name
)
elif lookup_type == "week":
return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (field_name, field_name)
else:
return "DATE(%s)" % (field_name)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname
if '+' in tzname:
return tzname[tzname.find('+'):]
elif '-' in tzname:
return tzname[tzname.find('-'):]
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
@@ -103,23 +105,16 @@ class DatabaseOperations(BaseDatabaseOperations):
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = ["year", "month", "day", "hour", "minute", "second"]
format = (
"%%Y-",
"%%m",
"-%%d",
" %%H:",
"%%i",
":%%s",
) # Use double percents to escape.
format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
if lookup_type == "quarter":
fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
format = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s') # Use double percents to escape.
format_def = ('0000-', '01', '-01', ' 00:', '00', ':00')
if lookup_type == 'quarter':
return (
"CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
"INTERVAL QUARTER({field_name}) QUARTER - "
+ "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
"INTERVAL QUARTER({field_name}) QUARTER - " +
"INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
).format(field_name=field_name)
if lookup_type == "week":
if lookup_type == 'week':
return (
"CAST(DATE_FORMAT(DATE_SUB({field_name}, "
"INTERVAL WEEKDAY({field_name}) DAY), "
@@ -130,16 +125,16 @@ class DatabaseOperations(BaseDatabaseOperations):
except ValueError:
sql = field_name
else:
format_str = "".join(format[:i] + format_def[i:])
format_str = ''.join(format[:i] + format_def[i:])
sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
return sql
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
fields = {
"hour": "%%H:00:00",
"minute": "%%H:%%i:00",
"second": "%%H:%%i:%%s",
'hour': '%%H:00:00',
'minute': '%%H:%%i:00',
'second': '%%H:%%i:%%s',
} # Use double percents to escape.
if lookup_type in fields:
format_str = fields[lookup_type]
@@ -155,7 +150,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
return "INTERVAL %s MICROSECOND" % sql
return 'INTERVAL %s MICROSECOND' % sql
def force_no_ordering(self):
"""
@@ -173,7 +168,7 @@ class DatabaseOperations(BaseDatabaseOperations):
# attribute where the exact query sent to the database is saved.
# See MySQLdb/cursors.py in the source distribution.
# MySQLdb returns string, PyMySQL bytes.
return force_str(getattr(cursor, "_executed", None), errors="replace")
return force_str(getattr(cursor, '_executed', None), errors='replace')
def no_limit_value(self):
# 2**64 - 1, as recommended by the MySQL documentation
@@ -188,67 +183,58 @@ class DatabaseOperations(BaseDatabaseOperations):
# MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
# statement.
if not fields:
return "", ()
return '', ()
columns = [
"%s.%s"
% (
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
) for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
return 'RETURNING %s' % ', '.join(columns), ()
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
if not tables:
return []
sql = ["SET FOREIGN_KEY_CHECKS = 0;"]
sql = ['SET FOREIGN_KEY_CHECKS = 0;']
if reset_sequences:
# It's faster to TRUNCATE tables that require a sequence reset
# since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
sql.extend(
"%s %s;"
% (
style.SQL_KEYWORD("TRUNCATE"),
'%s %s;' % (
style.SQL_KEYWORD('TRUNCATE'),
style.SQL_FIELD(self.quote_name(table_name)),
)
for table_name in tables
) for table_name in tables
)
else:
# Otherwise issue a simple DELETE since it's faster than TRUNCATE
# and preserves sequences.
sql.extend(
"%s %s %s;"
% (
style.SQL_KEYWORD("DELETE"),
style.SQL_KEYWORD("FROM"),
'%s %s %s;' % (
style.SQL_KEYWORD('DELETE'),
style.SQL_KEYWORD('FROM'),
style.SQL_FIELD(self.quote_name(table_name)),
)
for table_name in tables
) for table_name in tables
)
sql.append("SET FOREIGN_KEY_CHECKS = 1;")
sql.append('SET FOREIGN_KEY_CHECKS = 1;')
return sql
def sequence_reset_by_name_sql(self, style, sequences):
return [
"%s %s %s %s = 1;"
% (
style.SQL_KEYWORD("ALTER"),
style.SQL_KEYWORD("TABLE"),
style.SQL_FIELD(self.quote_name(sequence_info["table"])),
style.SQL_FIELD("AUTO_INCREMENT"),
)
for sequence_info in sequences
'%s %s %s %s = 1;' % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(sequence_info['table'])),
style.SQL_FIELD('AUTO_INCREMENT'),
) for sequence_info in sequences
]
def validate_autopk_value(self, value):
# Zero in AUTO_INCREMENT field does not work without the
# NO_AUTO_VALUE_ON_ZERO SQL mode.
if value == 0 and not self.connection.features.allows_auto_pk_0:
raise ValueError(
"The database backend does not accept 0 as a value for AutoField."
)
raise ValueError('The database backend does not accept 0 as a '
'value for AutoField.')
return value
def adapt_datetimefield_value(self, value):
@@ -256,7 +242,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware datetimes
@@ -264,10 +250,7 @@ class DatabaseOperations(BaseDatabaseOperations):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"MySQL backend does not support timezone-aware datetimes when "
"USE_TZ is False."
)
raise ValueError("MySQL backend does not support timezone-aware datetimes when USE_TZ is False.")
return str(value)
def adapt_timefield_value(self, value):
@@ -275,20 +258,20 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
# MySQL doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("MySQL backend does not support timezone-aware times.")
return value.isoformat(timespec="microseconds")
return str(value)
def max_name_length(self):
return 64
def pk_default_value(self):
return "NULL"
return 'NULL'
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
@@ -296,27 +279,27 @@ class DatabaseOperations(BaseDatabaseOperations):
return "VALUES " + values_sql
def combine_expression(self, connector, sub_expressions):
if connector == "^":
return "POW(%s)" % ",".join(sub_expressions)
if connector == '^':
return 'POW(%s)' % ','.join(sub_expressions)
# Convert the result to a signed integer since MySQL's binary operators
# return an unsigned integer.
elif connector in ("&", "|", "<<", "#"):
connector = "^" if connector == "#" else connector
return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
elif connector == ">>":
elif connector in ('&', '|', '<<', '#'):
connector = '^' if connector == '#' else connector
return 'CONVERT(%s, SIGNED)' % connector.join(sub_expressions)
elif connector == '>>':
lhs, rhs = sub_expressions
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
return 'FLOOR(%(lhs)s / POW(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
return super().combine_expression(connector, sub_expressions)
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == "BooleanField":
if internal_type in ['BooleanField', 'NullBooleanField']:
converters.append(self.convert_booleanfield_value)
elif internal_type == "DateTimeField":
elif internal_type == 'DateTimeField':
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == "UUIDField":
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
return converters
@@ -336,91 +319,62 @@ class DatabaseOperations(BaseDatabaseOperations):
return value
def binary_placeholder_sql(self, value):
return (
"_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
)
return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'
def subtract_temporals(self, internal_type, lhs, rhs):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
if internal_type == "TimeField":
if internal_type == 'TimeField':
if self.connection.mysql_is_mariadb:
# MariaDB includes the microsecond component in TIME_TO_SEC as
# a decimal. MySQL returns an integer without microseconds.
return (
"CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) "
"* 1000000 AS SIGNED)"
) % {
"lhs": lhs_sql,
"rhs": rhs_sql,
}, (
*lhs_params,
*rhs_params,
)
return 'CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)' % {
'lhs': lhs_sql, 'rhs': rhs_sql
}, (*lhs_params, *rhs_params)
return (
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(
rhs_params
) * 2
) % {'lhs': lhs_sql, 'rhs': rhs_sql}, tuple(lhs_params) * 2 + tuple(rhs_params) * 2
params = (*rhs_params, *lhs_params)
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
def explain_query_prefix(self, format=None, **options):
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
if format and format.upper() == "TEXT":
format = "TRADITIONAL"
elif (
not format and "TREE" in self.connection.features.supported_explain_formats
):
if format and format.upper() == 'TEXT':
format = 'TRADITIONAL'
elif not format and 'TREE' in self.connection.features.supported_explain_formats:
# Use TREE by default (if supported) as it's more informative.
format = "TREE"
analyze = options.pop("analyze", False)
format = 'TREE'
analyze = options.pop('analyze', False)
prefix = super().explain_query_prefix(format, **options)
if analyze and self.connection.features.supports_explain_analyze:
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
prefix = (
"ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
)
prefix = 'ANALYZE' if self.connection.mysql_is_mariadb else prefix + ' ANALYZE'
if format and not (analyze and not self.connection.mysql_is_mariadb):
# Only MariaDB supports the analyze option with formats.
prefix += " FORMAT=%s" % format
prefix += ' FORMAT=%s' % format
return prefix
def regex_lookup(self, lookup_type):
# REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
# doesn't exist in MySQL 5.x or in MariaDB.
if (
self.connection.mysql_version < (8, 0, 0)
or self.connection.mysql_is_mariadb
):
if lookup_type == "regex":
return "%s REGEXP BINARY %s"
return "%s REGEXP %s"
if self.connection.mysql_version < (8, 0, 0) or self.connection.mysql_is_mariadb:
if lookup_type == 'regex':
return '%s REGEXP BINARY %s'
return '%s REGEXP %s'
match_option = "c" if lookup_type == "regex" else "i"
match_option = 'c' if lookup_type == 'regex' else 'i'
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
def insert_statement(self, ignore_conflicts=False):
return (
"INSERT IGNORE INTO"
if ignore_conflicts
else super().insert_statement(ignore_conflicts)
)
return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
if internal_type == "JSONField":
lookup = '%s'
if internal_type == 'JSONField':
if self.connection.mysql_is_mariadb or lookup_type in (
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"regex",
"iregex",
'iexact', 'contains', 'icontains', 'startswith', 'istartswith',
'endswith', 'iendswith', 'regex', 'iregex',
):
lookup = "JSON_UNQUOTE(%s)"
lookup = 'JSON_UNQUOTE(%s)'
return lookup
@@ -10,26 +10,24 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
sql_alter_column_type = "MODIFY %(column)s %(type)s"
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL"
sql_alter_column_no_default_null = 'ALTER COLUMN %(column)s SET DEFAULT NULL'
# No 'CASCADE' which works as a no-op in MySQL but is undocumented
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
sql_create_column_inline_fk = (
", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
"REFERENCES %(to_table)s(%(to_column)s)"
', ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) '
'REFERENCES %(to_table)s(%(to_column)s)'
)
sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"
sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
sql_create_pk = (
"ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
)
sql_create_pk = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
sql_create_index = 'CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s'
@property
def sql_delete_check(self):
@@ -37,8 +35,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# The name of the column check constraint is the same as the field
# name on MariaDB. Adding IF EXISTS clause prevents migrations
# crash. Constraint is removed during a "MODIFY" column statement.
return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s"
return "ALTER TABLE %(table)s DROP CHECK %(name)s"
return 'ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s'
return 'ALTER TABLE %(table)s DROP CHECK %(name)s'
@property
def sql_rename_column(self):
@@ -49,26 +47,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
return super().sql_rename_column
elif self.connection.mysql_version >= (8, 0, 4):
return super().sql_rename_column
return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"
return 'ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s'
def quote_value(self, value):
self.connection.ensure_connection()
if isinstance(value, str):
value = value.replace("%", "%%")
value = value.replace('%', '%%')
# MySQLdb escapes to string, PyMySQL to bytes.
quoted = self.connection.connection.escape(
value, self.connection.connection.encoders
)
quoted = self.connection.connection.escape(value, self.connection.connection.encoders)
if isinstance(value, str) and isinstance(quoted, bytes):
quoted = quoted.decode()
return quoted
def _is_limited_data_type(self, field):
db_type = field.db_type(self.connection)
return (
db_type is not None
and db_type.lower() in self.connection._limited_data_types
)
return db_type is not None and db_type.lower() in self.connection._limited_data_types
def skip_default(self, field):
if not self._supports_limited_data_type_defaults:
@@ -92,13 +85,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _column_default_sql(self, field):
if (
not self.connection.mysql_is_mariadb
and self._supports_limited_data_type_defaults
and self._is_limited_data_type(field)
not self.connection.mysql_is_mariadb and
self._supports_limited_data_type_defaults and
self._is_limited_data_type(field)
):
# MySQL supports defaults for BLOB and TEXT columns only if the
# default value is written as an expression i.e. in parentheses.
return "(%s)"
return '(%s)'
return super()._column_default_sql(field)
def add_field(self, model, field):
@@ -108,32 +101,25 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# field.default may be unhashable, so a set isn't used for "in" check.
if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
effective_default = self.effective_default(field)
self.execute(
"UPDATE %(table)s SET %(column)s = %%s"
% {
"table": self.quote_name(model._meta.db_table),
"column": self.quote_name(field.column),
},
[effective_default],
)
self.execute('UPDATE %(table)s SET %(column)s = %%s' % {
'table': self.quote_name(model._meta.db_table),
'column': self.quote_name(field.column),
}, [effective_default])
def _field_should_be_indexed(self, model, field):
if not super()._field_should_be_indexed(model, field):
return False
create_index = super()._field_should_be_indexed(model, field)
storage = self.connection.introspection.get_storage_engine(
self.connection.cursor(), model._meta.db_table
)
# No need to create an index for ForeignKey fields except if
# db_constraint=False because the index from that constraint won't be
# created.
if (
storage == "InnoDB"
and field.get_internal_type() == "ForeignKey"
and field.db_constraint
):
if (storage == "InnoDB" and
create_index and
field.get_internal_type() == 'ForeignKey' and
field.db_constraint):
return False
return not self._is_limited_data_type(field)
return not self._is_limited_data_type(field) and create_index
def _delete_composed_index(self, model, fields, *args):
"""
@@ -145,13 +131,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
recreate a FK index.
"""
first_field = model._meta.get_field(fields[0])
if first_field.get_internal_type() == "ForeignKey":
constraint_names = self._constraint_names(
model, [first_field.column], index=True
)
if first_field.get_internal_type() == 'ForeignKey':
constraint_names = self._constraint_names(model, [first_field.column], index=True)
if not constraint_names:
self.execute(
self._create_index_sql(model, fields=[first_field], suffix="")
self._create_index_sql(model, fields=[first_field], suffix='')
)
return super()._delete_composed_index(model, fields, *args)
@@ -10,29 +10,24 @@ class DatabaseValidation(BaseDatabaseValidation):
return issues
def _check_sql_mode(self, **kwargs):
if not (
self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"}
):
return [
checks.Warning(
"%s Strict Mode is not set for database connection '%s'"
% (self.connection.display_name, self.connection.alias),
hint=(
"%s's Strict Mode fixes many data integrity problems in "
"%s, such as data truncation upon insertion, by "
"escalating warnings into errors. It is strongly "
"recommended you activate it. See: "
"https://docs.djangoproject.com/en/%s/ref/databases/"
"#mysql-sql-mode"
% (
self.connection.display_name,
self.connection.display_name,
get_docs_version(),
),
if not (self.connection.sql_mode & {'STRICT_TRANS_TABLES', 'STRICT_ALL_TABLES'}):
return [checks.Warning(
"%s Strict Mode is not set for database connection '%s'"
% (self.connection.display_name, self.connection.alias),
hint=(
"%s's Strict Mode fixes many data integrity problems in "
"%s, such as data truncation upon insertion, by "
"escalating warnings into errors. It is strongly "
"recommended you activate it. See: "
"https://docs.djangoproject.com/en/%s/ref/databases/#mysql-sql-mode"
% (
self.connection.display_name,
self.connection.display_name,
get_docs_version(),
),
id="mysql.W002",
)
]
),
id='mysql.W002',
)]
return []
def check_field_type(self, field, field_type):
@@ -43,35 +38,32 @@ class DatabaseValidation(BaseDatabaseValidation):
MySQL doesn't support a database index on some data types.
"""
errors = []
if (
field_type.startswith("varchar")
and field.unique
and (field.max_length is None or int(field.max_length) > 255)
):
if (field_type.startswith('varchar') and field.unique and
(field.max_length is None or int(field.max_length) > 255)):
errors.append(
checks.Warning(
"%s may not allow unique CharFields to have a max_length "
"> 255." % self.connection.display_name,
'%s may not allow unique CharFields to have a max_length '
'> 255.' % self.connection.display_name,
obj=field,
hint=(
"See: https://docs.djangoproject.com/en/%s/ref/"
"databases/#mysql-character-fields" % get_docs_version()
'See: https://docs.djangoproject.com/en/%s/ref/'
'databases/#mysql-character-fields' % get_docs_version()
),
id="mysql.W003",
id='mysql.W003',
)
)
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
"%s does not support a database index on %s columns."
'%s does not support a database index on %s columns.'
% (self.connection.display_name, field_type),
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id="fields.W162",
id='fields.W162',
)
)
return errors
@@ -21,31 +21,27 @@ from django.utils.functional import cached_property
def _setup_environment(environ):
# Cygwin requires some special voodoo to set the environment variables
# properly so that Oracle will see them.
if platform.system().upper().startswith("CYGWIN"):
if platform.system().upper().startswith('CYGWIN'):
try:
import ctypes
except ImportError as e:
raise ImproperlyConfigured(
"Error loading ctypes: %s; "
"the Oracle backend requires ctypes to "
"operate correctly under Cygwin." % e
)
kernel32 = ctypes.CDLL("kernel32")
raise ImproperlyConfigured("Error loading ctypes: %s; "
"the Oracle backend requires ctypes to "
"operate correctly under Cygwin." % e)
kernel32 = ctypes.CDLL('kernel32')
for name, value in environ:
kernel32.SetEnvironmentVariableA(name, value)
else:
os.environ.update(environ)
_setup_environment(
[
# Oracle takes client-side character set encoding from the environment.
("NLS_LANG", ".AL32UTF8"),
# This prevents Unicode from getting mangled by getting encoded into the
# potentially non-Unicode database character set.
("ORA_NCHAR_LITERAL_REPLACE", "TRUE"),
]
)
_setup_environment([
# Oracle takes client-side character set encoding from the environment.
('NLS_LANG', '.AL32UTF8'),
# This prevents Unicode from getting mangled by getting encoded into the
# potentially non-Unicode database character set.
('ORA_NCHAR_LITERAL_REPLACE', 'TRUE'),
])
try:
@@ -81,16 +77,17 @@ def wrap_oracle_errors():
# Convert that case to Django's IntegrityError exception.
x = e.args[0]
if (
hasattr(x, "code")
and hasattr(x, "message")
and x.code == 2091
and ("ORA-02291" in x.message or "ORA-00001" in x.message)
hasattr(x, 'code') and
hasattr(x, 'message') and
x.code == 2091 and
('ORA-02291' in x.message or 'ORA-00001' in x.message)
):
raise IntegrityError(*tuple(e.args))
raise
class _UninitializedOperatorsDescriptor:
def __get__(self, instance, cls=None):
# If connection.operators is looked up before a connection has been
# created, transparently initialize connection.operators to avert an
@@ -99,12 +96,12 @@ class _UninitializedOperatorsDescriptor:
raise AttributeError("operators not available as class attribute")
# Creating a cursor will initialize the operators.
instance.cursor().close()
return instance.__dict__["operators"]
return instance.__dict__['operators']
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "oracle"
display_name = "Oracle"
vendor = 'oracle'
display_name = 'Oracle'
# This dictionary maps Field objects to their associated Oracle column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
@@ -113,86 +110,73 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Any format strings starting with "qn_" are quoted before being used in the
# output (the "qn_" prefix is stripped before the lookup is performed.
data_types = {
"AutoField": "NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY",
"BigAutoField": "NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY",
"BinaryField": "BLOB",
"BooleanField": "NUMBER(1)",
"CharField": "NVARCHAR2(%(max_length)s)",
"DateField": "DATE",
"DateTimeField": "TIMESTAMP",
"DecimalField": "NUMBER(%(max_digits)s, %(decimal_places)s)",
"DurationField": "INTERVAL DAY(9) TO SECOND(6)",
"FileField": "NVARCHAR2(%(max_length)s)",
"FilePathField": "NVARCHAR2(%(max_length)s)",
"FloatField": "DOUBLE PRECISION",
"IntegerField": "NUMBER(11)",
"JSONField": "NCLOB",
"BigIntegerField": "NUMBER(19)",
"IPAddressField": "VARCHAR2(15)",
"GenericIPAddressField": "VARCHAR2(39)",
"OneToOneField": "NUMBER(11)",
"PositiveBigIntegerField": "NUMBER(19)",
"PositiveIntegerField": "NUMBER(11)",
"PositiveSmallIntegerField": "NUMBER(11)",
"SlugField": "NVARCHAR2(%(max_length)s)",
"SmallAutoField": "NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY",
"SmallIntegerField": "NUMBER(11)",
"TextField": "NCLOB",
"TimeField": "TIMESTAMP",
"URLField": "VARCHAR2(%(max_length)s)",
"UUIDField": "VARCHAR2(32)",
'AutoField': 'NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'BigAutoField': 'NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'BinaryField': 'BLOB',
'BooleanField': 'NUMBER(1)',
'CharField': 'NVARCHAR2(%(max_length)s)',
'DateField': 'DATE',
'DateTimeField': 'TIMESTAMP',
'DecimalField': 'NUMBER(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'INTERVAL DAY(9) TO SECOND(6)',
'FileField': 'NVARCHAR2(%(max_length)s)',
'FilePathField': 'NVARCHAR2(%(max_length)s)',
'FloatField': 'DOUBLE PRECISION',
'IntegerField': 'NUMBER(11)',
'JSONField': 'NCLOB',
'BigIntegerField': 'NUMBER(19)',
'IPAddressField': 'VARCHAR2(15)',
'GenericIPAddressField': 'VARCHAR2(39)',
'NullBooleanField': 'NUMBER(1)',
'OneToOneField': 'NUMBER(11)',
'PositiveBigIntegerField': 'NUMBER(19)',
'PositiveIntegerField': 'NUMBER(11)',
'PositiveSmallIntegerField': 'NUMBER(11)',
'SlugField': 'NVARCHAR2(%(max_length)s)',
'SmallAutoField': 'NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY',
'SmallIntegerField': 'NUMBER(11)',
'TextField': 'NCLOB',
'TimeField': 'TIMESTAMP',
'URLField': 'VARCHAR2(%(max_length)s)',
'UUIDField': 'VARCHAR2(32)',
}
data_type_check_constraints = {
"BooleanField": "%(qn_column)s IN (0,1)",
"JSONField": "%(qn_column)s IS JSON",
"PositiveBigIntegerField": "%(qn_column)s >= 0",
"PositiveIntegerField": "%(qn_column)s >= 0",
"PositiveSmallIntegerField": "%(qn_column)s >= 0",
'BooleanField': '%(qn_column)s IN (0,1)',
'JSONField': '%(qn_column)s IS JSON',
'NullBooleanField': '%(qn_column)s IN (0,1)',
'PositiveBigIntegerField': '%(qn_column)s >= 0',
'PositiveIntegerField': '%(qn_column)s >= 0',
'PositiveSmallIntegerField': '%(qn_column)s >= 0',
}
# Oracle doesn't support a database index on these columns.
_limited_data_types = ("clob", "nclob", "blob")
_limited_data_types = ('clob', 'nclob', 'blob')
operators = _UninitializedOperatorsDescriptor()
_standard_operators = {
"exact": "= %s",
"iexact": "= UPPER(%s)",
"contains": (
"LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
),
"icontains": (
"LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
"ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
),
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": (
"LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
),
"endswith": (
"LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
),
"istartswith": (
"LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
"ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
),
"iendswith": (
"LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
"ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
),
'exact': '= %s',
'iexact': '= UPPER(%s)',
'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
}
_likec_operators = {
**_standard_operators,
"contains": "LIKEC %s ESCAPE '\\'",
"icontains": "LIKEC UPPER(%s) ESCAPE '\\'",
"startswith": "LIKEC %s ESCAPE '\\'",
"endswith": "LIKEC %s ESCAPE '\\'",
"istartswith": "LIKEC UPPER(%s) ESCAPE '\\'",
"iendswith": "LIKEC UPPER(%s) ESCAPE '\\'",
'contains': "LIKEC %s ESCAPE '\\'",
'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
'startswith': "LIKEC %s ESCAPE '\\'",
'endswith': "LIKEC %s ESCAPE '\\'",
'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -205,22 +189,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
_pattern_ops = {
"contains": "'%%' || {} || '%%'",
"icontains": "'%%' || UPPER({}) || '%%'",
"startswith": "{} || '%%'",
"istartswith": "UPPER({}) || '%%'",
"endswith": "'%%' || {}",
"iendswith": "'%%' || UPPER({})",
'contains': "'%%' || {} || '%%'",
'icontains': "'%%' || UPPER({}) || '%%'",
'startswith': "{} || '%%'",
'istartswith': "UPPER({}) || '%%'",
'endswith': "'%%' || {}",
'iendswith': "'%%' || UPPER({})",
}
_standard_pattern_ops = {
k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
" ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
for k, v in _pattern_ops.items()
}
_likec_pattern_ops = {
k: "LIKEC " + v + " ESCAPE '\\'" for k, v in _pattern_ops.items()
}
_standard_pattern_ops = {k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
" ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
for k, v in _pattern_ops.items()}
_likec_pattern_ops = {k: "LIKEC " + v + " ESCAPE '\\'"
for k, v in _pattern_ops.items()}
Database = Database
SchemaEditorClass = DatabaseSchemaEditor
@@ -234,22 +215,20 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
use_returning_into = self.settings_dict["OPTIONS"].get(
"use_returning_into", True
)
use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True)
self.features.can_return_columns_from_insert = use_returning_into
def get_connection_params(self):
conn_params = self.settings_dict["OPTIONS"].copy()
if "use_returning_into" in conn_params:
del conn_params["use_returning_into"]
conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params:
del conn_params['use_returning_into']
return conn_params
@async_unsafe
def get_new_connection(self, conn_params):
return Database.connect(
user=self.settings_dict["USER"],
password=self.settings_dict["PASSWORD"],
user=self.settings_dict['USER'],
password=self.settings_dict['PASSWORD'],
dsn=dsn(self.settings_dict),
**conn_params,
)
@@ -267,11 +246,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# TO_CHAR().
cursor.execute(
"ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
+ (" TIME_ZONE = 'UTC'" if settings.USE_TZ else "")
" NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'" +
(" TIME_ZONE = 'UTC'" if settings.USE_TZ else '')
)
cursor.close()
if "operators" not in self.__dict__:
if 'operators' not in self.__dict__:
# Ticket #14149: Check whether our LIKE implementation will
# work for this connection or we need to fall back on LIKEC.
# This check is performed only once per DatabaseWrapper
@@ -279,11 +258,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the same settings.
cursor = self.create_cursor()
try:
cursor.execute(
"SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators["contains"],
["X"],
)
cursor.execute("SELECT 1 FROM DUAL WHERE DUMMY %s"
% self._standard_operators['contains'],
['X'])
except Database.DatabaseError:
self.operators = self._likec_operators
self.pattern_ops = self._likec_pattern_ops
@@ -309,12 +286,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# logging is enabled to keep query counts consistent with other backends.
def _savepoint_commit(self, sid):
if self.queries_logged:
self.queries_log.append(
{
"sql": "-- RELEASE SAVEPOINT %s (faked)" % self.ops.quote_name(sid),
"time": "0.000",
}
)
self.queries_log.append({
'sql': '-- RELEASE SAVEPOINT %s (faked)' % self.ops.quote_name(sid),
'time': '0.000',
})
def _set_autocommit(self, autocommit):
with self.wrap_database_errors:
@@ -326,8 +301,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
afterward.
"""
with self.cursor() as cursor:
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
@@ -339,12 +314,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@cached_property
def cx_oracle_version(self):
return tuple(int(x) for x in Database.version.split("."))
return tuple(int(x) for x in Database.version.split('.'))
@cached_property
def oracle_version(self):
with self.temporary_connection():
return tuple(int(x) for x in self.connection.version.split("."))
return tuple(int(x) for x in self.connection.version.split('.'))
class OracleParam:
@@ -360,10 +335,8 @@ class OracleParam:
def __init__(self, param, cursor, strings_only=False):
# With raw SQL queries, datetimes can reach this function
# without being converted by DateTimeField.get_db_prep_value.
if settings.USE_TZ and (
isinstance(param, datetime.datetime)
and not isinstance(param, Oracle_datetime)
):
if settings.USE_TZ and (isinstance(param, datetime.datetime) and
not isinstance(param, Oracle_datetime)):
param = Oracle_datetime.from_datetime(param)
string_size = 0
@@ -372,7 +345,7 @@ class OracleParam:
param = 1
elif param is False:
param = 0
if hasattr(param, "bind_parameter"):
if hasattr(param, 'bind_parameter'):
self.force_bytes = param.bind_parameter(cursor)
elif isinstance(param, (Database.Binary, datetime.timedelta)):
self.force_bytes = param
@@ -383,7 +356,7 @@ class OracleParam:
if isinstance(self.force_bytes, str):
# We could optimize by only converting up to 4000 bytes here
string_size = len(force_bytes(param, cursor.charset, strings_only))
if hasattr(param, "input_size"):
if hasattr(param, 'input_size'):
# If parameter has `input_size` attribute, use that.
self.input_size = param.input_size
elif string_size > 4000:
@@ -413,7 +386,7 @@ class VariableWrapper:
return getattr(self.var, key)
def __setattr__(self, key, value):
if key == "var":
if key == 'var':
self.__dict__[key] = value
else:
setattr(self.var, key, value)
@@ -425,8 +398,7 @@ class FormatStylePlaceholderCursor:
style. This fixes it -- but note that if you want to use a literal "%s" in
a query, you'll need to use "%%s".
"""
charset = "utf-8"
charset = 'utf-8'
def __init__(self, connection):
self.cursor = connection.cursor()
@@ -434,7 +406,7 @@ class FormatStylePlaceholderCursor:
@staticmethod
def _output_number_converter(value):
return decimal.Decimal(value) if "." in value else int(value)
return decimal.Decimal(value) if '.' in value else int(value)
@staticmethod
def _get_decimal_converter(precision, scale):
@@ -464,9 +436,7 @@ class FormatStylePlaceholderCursor:
elif precision > 0:
# NUMBER(p,s) column: decimal-precision fixed point.
# This comes from IntegerField and DecimalField columns.
outconverter = FormatStylePlaceholderCursor._get_decimal_converter(
precision, scale
)
outconverter = FormatStylePlaceholderCursor._get_decimal_converter(precision, scale)
else:
# No type information. This normally comes from a
# mathematical expression in the SELECT list. Guess int
@@ -487,7 +457,7 @@ class FormatStylePlaceholderCursor:
def _guess_input_sizes(self, params_list):
# Try dict handling; if that fails, treat as sequence
if hasattr(params_list[0], "keys"):
if hasattr(params_list[0], 'keys'):
sizes = {}
for params in params_list:
for k, value in params.items():
@@ -507,7 +477,7 @@ class FormatStylePlaceholderCursor:
def _param_generator(self, params):
# Try dict handling; if that fails, treat as sequence
if hasattr(params, "items"):
if hasattr(params, 'items'):
return {k: v.force_bytes for k, v in params.items()}
else:
return [p.force_bytes for p in params]
@@ -517,11 +487,11 @@ class FormatStylePlaceholderCursor:
# it does want a trailing ';' but not a trailing '/'. However, these
# characters must be included in the original query in case the query
# is being passed to SQL*Plus.
if query.endswith(";") or query.endswith("/"):
if query.endswith(';') or query.endswith('/'):
query = query[:-1]
if params is None:
params = []
elif hasattr(params, "keys"):
elif hasattr(params, 'keys'):
# Handle params as dict
args = {k: ":%s" % k for k in params}
query = query % args
@@ -534,14 +504,15 @@ class FormatStylePlaceholderCursor:
# args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0']
# params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'}
params_dict = {
param: ":arg%d" % i for i, param in enumerate(dict.fromkeys(params))
param: ':arg%d' % i
for i, param in enumerate(dict.fromkeys(params))
}
args = [params_dict[param] for param in params]
params = {value: key for key, value in params_dict.items()}
query = query % tuple(args)
else:
# Handle params as sequence
args = [(":arg%d" % i) for i in range(len(params))]
args = [(':arg%d' % i) for i in range(len(params))]
query = query % tuple(args)
return query, self._format_params(params)
@@ -563,9 +534,7 @@ class FormatStylePlaceholderCursor:
formatted = [firstparams] + [self._format_params(p) for p in params_iter]
self._guess_input_sizes(formatted)
with wrap_oracle_errors():
return self.cursor.executemany(
query, [self._param_generator(p) for p in formatted]
)
return self.cursor.executemany(query, [self._param_generator(p) for p in formatted])
def close(self):
try:
@@ -4,22 +4,22 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "sqlplus"
wrapper_name = "rlwrap"
executable_name = 'sqlplus'
wrapper_name = 'rlwrap'
@staticmethod
def connect_string(settings_dict):
from django.db.backends.oracle.utils import dsn
return '%s/"%s"@%s' % (
settings_dict["USER"],
settings_dict["PASSWORD"],
settings_dict['USER'],
settings_dict['PASSWORD'],
dsn(settings_dict),
)
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, "-L", cls.connect_string(settings_dict)]
args = [cls.executable_name, '-L', cls.connect_string(settings_dict)]
wrapper_path = shutil.which(cls.wrapper_name)
if wrapper_path:
args = [wrapper_path, *args]
@@ -6,10 +6,11 @@ from django.db.backends.base.creation import BaseDatabaseCreation
from django.utils.crypto import get_random_string
from django.utils.functional import cached_property
TEST_DATABASE_PREFIX = "test_"
TEST_DATABASE_PREFIX = 'test_'
class DatabaseCreation(BaseDatabaseCreation):
@cached_property
def _maindb_connection(self):
"""
@@ -20,9 +21,9 @@ class DatabaseCreation(BaseDatabaseCreation):
is the main (non-test) connection.
"""
settings_dict = settings.DATABASES[self.connection.alias]
user = settings_dict.get("SAVED_USER") or settings_dict["USER"]
password = settings_dict.get("SAVED_PASSWORD") or settings_dict["PASSWORD"]
settings_dict = {**settings_dict, "USER": user, "PASSWORD": password}
user = settings_dict.get('SAVED_USER') or settings_dict['USER']
password = settings_dict.get('SAVED_PASSWORD') or settings_dict['PASSWORD']
settings_dict = {**settings_dict, 'USER': user, 'PASSWORD': password}
DatabaseWrapper = type(self.connection)
return DatabaseWrapper(settings_dict, alias=self.connection.alias)
@@ -31,97 +32,72 @@ class DatabaseCreation(BaseDatabaseCreation):
with self._maindb_connection.cursor() as cursor:
if self._test_database_create():
try:
self._execute_test_db_creation(
cursor, parameters, verbosity, keepdb
)
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e:
if "ORA-01543" not in str(e):
if 'ORA-01543' not in str(e):
# All errors except "tablespace already exists" cancel tests
self.log("Got an error creating the test database: %s" % e)
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test database, %s, already exists. "
"Type 'yes' to delete it, or 'no' to cancel: "
% parameters["user"]
)
if autoclobber or confirm == "yes":
"Type 'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
if verbosity >= 1:
self.log(
"Destroying old test database for alias '%s'..."
% self.connection.alias
)
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
try:
self._execute_test_db_destruction(
cursor, parameters, verbosity
)
self._execute_test_db_destruction(cursor, parameters, verbosity)
except DatabaseError as e:
if "ORA-29857" in str(e):
self._handle_objects_preventing_db_destruction(
cursor, parameters, verbosity, autoclobber
)
if 'ORA-29857' in str(e):
self._handle_objects_preventing_db_destruction(cursor, parameters,
verbosity, autoclobber)
else:
# Ran into a database error that isn't about
# leftover objects in the tablespace.
self.log(
"Got an error destroying the old test database: %s"
% e
)
# Ran into a database error that isn't about leftover objects in the tablespace
self.log('Got an error destroying the old test database: %s' % e)
sys.exit(2)
except Exception as e:
self.log(
"Got an error destroying the old test database: %s" % e
)
self.log('Got an error destroying the old test database: %s' % e)
sys.exit(2)
try:
self._execute_test_db_creation(
cursor, parameters, verbosity, keepdb
)
self._execute_test_db_creation(cursor, parameters, verbosity, keepdb)
except Exception as e:
self.log(
"Got an error recreating the test database: %s" % e
)
self.log('Got an error recreating the test database: %s' % e)
sys.exit(2)
else:
self.log("Tests cancelled.")
self.log('Tests cancelled.')
sys.exit(1)
if self._test_user_create():
if verbosity >= 1:
self.log("Creating test user...")
self.log('Creating test user...')
try:
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
if "ORA-01920" not in str(e):
if 'ORA-01920' not in str(e):
# All errors except "user already exists" cancel tests
self.log("Got an error creating the test user: %s" % e)
self.log('Got an error creating the test user: %s' % e)
sys.exit(2)
if not autoclobber:
confirm = input(
"It appears the test user, %s, already exists. Type "
"'yes' to delete it, or 'no' to cancel: "
% parameters["user"]
)
if autoclobber or confirm == "yes":
"'yes' to delete it, or 'no' to cancel: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log("Destroying old test user...")
self.log('Destroying old test user...')
self._destroy_test_user(cursor, parameters, verbosity)
if verbosity >= 1:
self.log("Creating test user...")
self._create_test_user(
cursor, parameters, verbosity, keepdb
)
self.log('Creating test user...')
self._create_test_user(cursor, parameters, verbosity, keepdb)
except Exception as e:
self.log("Got an error recreating the test user: %s" % e)
self.log('Got an error recreating the test user: %s' % e)
sys.exit(2)
else:
self.log("Tests cancelled.")
self.log('Tests cancelled.')
sys.exit(1)
# Done with main user -- test user and tablespaces created.
self._maindb_connection.close()
self._maindb_connection.close() # done with main user -- test user and tablespaces created
self._switch_to_test_user(parameters)
return self.connection.settings_dict["NAME"]
return self.connection.settings_dict['NAME']
def _switch_to_test_user(self, parameters):
"""
@@ -133,71 +109,59 @@ class DatabaseCreation(BaseDatabaseCreation):
credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict.
"""
real_settings = settings.DATABASES[self.connection.alias]
real_settings["SAVED_USER"] = self.connection.settings_dict[
"SAVED_USER"
] = self.connection.settings_dict["USER"]
real_settings["SAVED_PASSWORD"] = self.connection.settings_dict[
"SAVED_PASSWORD"
] = self.connection.settings_dict["PASSWORD"]
real_test_settings = real_settings["TEST"]
test_settings = self.connection.settings_dict["TEST"]
real_test_settings["USER"] = real_settings["USER"] = test_settings[
"USER"
] = self.connection.settings_dict["USER"] = parameters["user"]
real_settings["PASSWORD"] = self.connection.settings_dict[
"PASSWORD"
] = parameters["password"]
real_settings['SAVED_USER'] = self.connection.settings_dict['SAVED_USER'] = \
self.connection.settings_dict['USER']
real_settings['SAVED_PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD'] = \
self.connection.settings_dict['PASSWORD']
real_test_settings = real_settings['TEST']
test_settings = self.connection.settings_dict['TEST']
real_test_settings['USER'] = real_settings['USER'] = test_settings['USER'] = \
self.connection.settings_dict['USER'] = parameters['user']
real_settings['PASSWORD'] = self.connection.settings_dict['PASSWORD'] = parameters['password']
def set_as_test_mirror(self, primary_settings_dict):
"""
Set this database up to be used in testing as a mirror of a primary
database whose settings are given.
"""
self.connection.settings_dict["USER"] = primary_settings_dict["USER"]
self.connection.settings_dict["PASSWORD"] = primary_settings_dict["PASSWORD"]
self.connection.settings_dict['USER'] = primary_settings_dict['USER']
self.connection.settings_dict['PASSWORD'] = primary_settings_dict['PASSWORD']
def _handle_objects_preventing_db_destruction(
self, cursor, parameters, verbosity, autoclobber
):
def _handle_objects_preventing_db_destruction(self, cursor, parameters, verbosity, autoclobber):
# There are objects in the test tablespace which prevent dropping it
# The easy fix is to drop the test user -- but are we allowed to do so?
self.log(
"There are objects in the old test database which prevent its destruction."
"\nIf they belong to the test user, deleting the user will allow the test "
"database to be recreated.\n"
"Otherwise, you will need to find and remove each of these objects, "
"or use a different tablespace.\n"
'There are objects in the old test database which prevent its destruction.\n'
'If they belong to the test user, deleting the user will allow the test '
'database to be recreated.\n'
'Otherwise, you will need to find and remove each of these objects, '
'or use a different tablespace.\n'
)
if self._test_user_create():
if not autoclobber:
confirm = input("Type 'yes' to delete user %s: " % parameters["user"])
if autoclobber or confirm == "yes":
confirm = input("Type 'yes' to delete user %s: " % parameters['user'])
if autoclobber or confirm == 'yes':
try:
if verbosity >= 1:
self.log("Destroying old test user...")
self.log('Destroying old test user...')
self._destroy_test_user(cursor, parameters, verbosity)
except Exception as e:
self.log("Got an error destroying the test user: %s" % e)
self.log('Got an error destroying the test user: %s' % e)
sys.exit(2)
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias '%s'..."
% self.connection.alias
)
self.log("Destroying old test database for alias '%s'..." % self.connection.alias)
self._execute_test_db_destruction(cursor, parameters, verbosity)
except Exception as e:
self.log("Got an error destroying the test database: %s" % e)
self.log('Got an error destroying the test database: %s' % e)
sys.exit(2)
else:
self.log("Tests cancelled -- test database cannot be recreated.")
self.log('Tests cancelled -- test database cannot be recreated.')
sys.exit(1)
else:
self.log(
"Django is configured to use pre-existing test user '%s',"
" and will not attempt to delete it." % parameters["user"]
)
self.log("Tests cancelled -- test database cannot be recreated.")
self.log("Django is configured to use pre-existing test user '%s',"
" and will not attempt to delete it." % parameters['user'])
self.log('Tests cancelled -- test database cannot be recreated.')
sys.exit(1)
def _destroy_test_db(self, test_database_name, verbosity=1):
@@ -205,28 +169,24 @@ class DatabaseCreation(BaseDatabaseCreation):
Destroy a test database, prompting the user for confirmation if the
database already exists. Return the name of the test database created.
"""
self.connection.settings_dict["USER"] = self.connection.settings_dict[
"SAVED_USER"
]
self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[
"SAVED_PASSWORD"
]
self.connection.settings_dict['USER'] = self.connection.settings_dict['SAVED_USER']
self.connection.settings_dict['PASSWORD'] = self.connection.settings_dict['SAVED_PASSWORD']
self.connection.close()
parameters = self._get_test_db_params()
with self._maindb_connection.cursor() as cursor:
if self._test_user_create():
if verbosity >= 1:
self.log("Destroying test user...")
self.log('Destroying test user...')
self._destroy_test_user(cursor, parameters, verbosity)
if self._test_database_create():
if verbosity >= 1:
self.log("Destroying test database tables...")
self.log('Destroying test database tables...')
self._execute_test_db_destruction(cursor, parameters, verbosity)
self._maindb_connection.close()
def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
self.log("_create_test_db(): dbname = %s" % parameters["user"])
self.log('_create_test_db(): dbname = %s' % parameters['user'])
if self._test_database_oracle_managed_files():
statements = [
"""
@@ -254,14 +214,12 @@ class DatabaseCreation(BaseDatabaseCreation):
""",
]
# Ignore "tablespace already exists" error when keepdb is on.
acceptable_ora_err = "ORA-01543" if keepdb else None
self._execute_allow_fail_statements(
cursor, statements, parameters, verbosity, acceptable_ora_err
)
acceptable_ora_err = 'ORA-01543' if keepdb else None
self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
def _create_test_user(self, cursor, parameters, verbosity, keepdb=False):
if verbosity >= 2:
self.log("_create_test_user(): username = %s" % parameters["user"])
self.log('_create_test_user(): username = %s' % parameters['user'])
statements = [
"""CREATE USER %(user)s
IDENTIFIED BY "%(password)s"
@@ -277,51 +235,40 @@ class DatabaseCreation(BaseDatabaseCreation):
TO %(user)s""",
]
# Ignore "user already exists" error when keepdb is on
acceptable_ora_err = "ORA-01920" if keepdb else None
success = self._execute_allow_fail_statements(
cursor, statements, parameters, verbosity, acceptable_ora_err
)
acceptable_ora_err = 'ORA-01920' if keepdb else None
success = self._execute_allow_fail_statements(cursor, statements, parameters, verbosity, acceptable_ora_err)
# If the password was randomly generated, change the user accordingly.
if not success and self._test_settings_get("PASSWORD") is None:
if not success and self._test_settings_get('PASSWORD') is None:
set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"'
self._execute_statements(cursor, [set_password], parameters, verbosity)
# Most test suites can be run without "create view" and
# "create materialized view" privileges. But some need it.
for object_type in ("VIEW", "MATERIALIZED VIEW"):
extra = "GRANT CREATE %(object_type)s TO %(user)s"
parameters["object_type"] = object_type
success = self._execute_allow_fail_statements(
cursor, [extra], parameters, verbosity, "ORA-01031"
)
for object_type in ('VIEW', 'MATERIALIZED VIEW'):
extra = 'GRANT CREATE %(object_type)s TO %(user)s'
parameters['object_type'] = object_type
success = self._execute_allow_fail_statements(cursor, [extra], parameters, verbosity, 'ORA-01031')
if not success and verbosity >= 2:
self.log(
"Failed to grant CREATE %s permission to test user. This may be ok."
% object_type
)
self.log('Failed to grant CREATE %s permission to test user. This may be ok.' % object_type)
def _execute_test_db_destruction(self, cursor, parameters, verbosity):
if verbosity >= 2:
self.log("_execute_test_db_destruction(): dbname=%s" % parameters["user"])
self.log('_execute_test_db_destruction(): dbname=%s' % parameters['user'])
statements = [
"DROP TABLESPACE %(tblspace)s "
"INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
"DROP TABLESPACE %(tblspace_temp)s "
"INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS',
]
self._execute_statements(cursor, statements, parameters, verbosity)
def _destroy_test_user(self, cursor, parameters, verbosity):
if verbosity >= 2:
self.log("_destroy_test_user(): user=%s" % parameters["user"])
self.log("Be patient. This can take some time...")
self.log('_destroy_test_user(): user=%s' % parameters['user'])
self.log('Be patient. This can take some time...')
statements = [
"DROP USER %(user)s CASCADE",
'DROP USER %(user)s CASCADE',
]
self._execute_statements(cursor, statements, parameters, verbosity)
def _execute_statements(
self, cursor, statements, parameters, verbosity, allow_quiet_fail=False
):
def _execute_statements(self, cursor, statements, parameters, verbosity, allow_quiet_fail=False):
for template in statements:
stmt = template % parameters
if verbosity >= 2:
@@ -330,12 +277,10 @@ class DatabaseCreation(BaseDatabaseCreation):
cursor.execute(stmt)
except Exception as err:
if (not allow_quiet_fail) or verbosity >= 2:
self.log("Failed (%s)" % (err))
self.log('Failed (%s)' % (err))
raise
def _execute_allow_fail_statements(
self, cursor, statements, parameters, verbosity, acceptable_ora_err
):
def _execute_allow_fail_statements(self, cursor, statements, parameters, verbosity, acceptable_ora_err):
"""
Execute statements which are allowed to fail silently if the Oracle
error code given by `acceptable_ora_err` is raised. Return True if the
@@ -343,16 +288,8 @@ class DatabaseCreation(BaseDatabaseCreation):
"""
try:
# Statement can fail when acceptable_ora_err is not None
allow_quiet_fail = (
acceptable_ora_err is not None and len(acceptable_ora_err) > 0
)
self._execute_statements(
cursor,
statements,
parameters,
verbosity,
allow_quiet_fail=allow_quiet_fail,
)
allow_quiet_fail = acceptable_ora_err is not None and len(acceptable_ora_err) > 0
self._execute_statements(cursor, statements, parameters, verbosity, allow_quiet_fail=allow_quiet_fail)
return True
except DatabaseError as err:
description = str(err)
@@ -362,19 +299,19 @@ class DatabaseCreation(BaseDatabaseCreation):
def _get_test_db_params(self):
return {
"dbname": self._test_database_name(),
"user": self._test_database_user(),
"password": self._test_database_passwd(),
"tblspace": self._test_database_tblspace(),
"tblspace_temp": self._test_database_tblspace_tmp(),
"datafile": self._test_database_tblspace_datafile(),
"datafile_tmp": self._test_database_tblspace_tmp_datafile(),
"maxsize": self._test_database_tblspace_maxsize(),
"maxsize_tmp": self._test_database_tblspace_tmp_maxsize(),
"size": self._test_database_tblspace_size(),
"size_tmp": self._test_database_tblspace_tmp_size(),
"extsize": self._test_database_tblspace_extsize(),
"extsize_tmp": self._test_database_tblspace_tmp_extsize(),
'dbname': self._test_database_name(),
'user': self._test_database_user(),
'password': self._test_database_passwd(),
'tblspace': self._test_database_tblspace(),
'tblspace_temp': self._test_database_tblspace_tmp(),
'datafile': self._test_database_tblspace_datafile(),
'datafile_tmp': self._test_database_tblspace_tmp_datafile(),
'maxsize': self._test_database_tblspace_maxsize(),
'maxsize_tmp': self._test_database_tblspace_tmp_maxsize(),
'size': self._test_database_tblspace_size(),
'size_tmp': self._test_database_tblspace_tmp_size(),
'extsize': self._test_database_tblspace_extsize(),
'extsize_tmp': self._test_database_tblspace_tmp_extsize(),
}
def _test_settings_get(self, key, default=None, prefixed=None):
@@ -383,67 +320,66 @@ class DatabaseCreation(BaseDatabaseCreation):
prefixed entry from the main settings dict.
"""
settings_dict = self.connection.settings_dict
val = settings_dict["TEST"].get(key, default)
val = settings_dict['TEST'].get(key, default)
if val is None and prefixed:
val = TEST_DATABASE_PREFIX + settings_dict[prefixed]
return val
def _test_database_name(self):
return self._test_settings_get("NAME", prefixed="NAME")
return self._test_settings_get('NAME', prefixed='NAME')
def _test_database_create(self):
return self._test_settings_get("CREATE_DB", default=True)
return self._test_settings_get('CREATE_DB', default=True)
def _test_user_create(self):
return self._test_settings_get("CREATE_USER", default=True)
return self._test_settings_get('CREATE_USER', default=True)
def _test_database_user(self):
return self._test_settings_get("USER", prefixed="USER")
return self._test_settings_get('USER', prefixed='USER')
def _test_database_passwd(self):
password = self._test_settings_get("PASSWORD")
password = self._test_settings_get('PASSWORD')
if password is None and self._test_user_create():
# Oracle passwords are limited to 30 chars and can't contain symbols.
password = get_random_string(30)
return password
def _test_database_tblspace(self):
return self._test_settings_get("TBLSPACE", prefixed="USER")
return self._test_settings_get('TBLSPACE', prefixed='USER')
def _test_database_tblspace_tmp(self):
settings_dict = self.connection.settings_dict
return settings_dict["TEST"].get(
"TBLSPACE_TMP", TEST_DATABASE_PREFIX + settings_dict["USER"] + "_temp"
)
return settings_dict['TEST'].get('TBLSPACE_TMP',
TEST_DATABASE_PREFIX + settings_dict['USER'] + '_temp')
def _test_database_tblspace_datafile(self):
tblspace = "%s.dbf" % self._test_database_tblspace()
return self._test_settings_get("DATAFILE", default=tblspace)
tblspace = '%s.dbf' % self._test_database_tblspace()
return self._test_settings_get('DATAFILE', default=tblspace)
def _test_database_tblspace_tmp_datafile(self):
tblspace = "%s.dbf" % self._test_database_tblspace_tmp()
return self._test_settings_get("DATAFILE_TMP", default=tblspace)
tblspace = '%s.dbf' % self._test_database_tblspace_tmp()
return self._test_settings_get('DATAFILE_TMP', default=tblspace)
def _test_database_tblspace_maxsize(self):
return self._test_settings_get("DATAFILE_MAXSIZE", default="500M")
return self._test_settings_get('DATAFILE_MAXSIZE', default='500M')
def _test_database_tblspace_tmp_maxsize(self):
return self._test_settings_get("DATAFILE_TMP_MAXSIZE", default="500M")
return self._test_settings_get('DATAFILE_TMP_MAXSIZE', default='500M')
def _test_database_tblspace_size(self):
return self._test_settings_get("DATAFILE_SIZE", default="50M")
return self._test_settings_get('DATAFILE_SIZE', default='50M')
def _test_database_tblspace_tmp_size(self):
return self._test_settings_get("DATAFILE_TMP_SIZE", default="50M")
return self._test_settings_get('DATAFILE_TMP_SIZE', default='50M')
def _test_database_tblspace_extsize(self):
return self._test_settings_get("DATAFILE_EXTSIZE", default="25M")
return self._test_settings_get('DATAFILE_EXTSIZE', default='25M')
def _test_database_tblspace_tmp_extsize(self):
return self._test_settings_get("DATAFILE_TMP_EXTSIZE", default="25M")
return self._test_settings_get('DATAFILE_TMP_EXTSIZE', default='25M')
def _test_database_oracle_managed_files(self):
return self._test_settings_get("ORACLE_MANAGED_FILES", default=False)
return self._test_settings_get('ORACLE_MANAGED_FILES', default=False)
def _get_test_db_name(self):
"""
@@ -451,14 +387,14 @@ class DatabaseCreation(BaseDatabaseCreation):
to work. This isn't a great deal in this case because DB names as
handled by Django don't have real counterparts in Oracle.
"""
return self.connection.settings_dict["NAME"]
return self.connection.settings_dict['NAME']
def test_db_signature(self):
settings_dict = self.connection.settings_dict
return (
settings_dict["HOST"],
settings_dict["PORT"],
settings_dict["ENGINE"],
settings_dict["NAME"],
settings_dict['HOST'],
settings_dict['PORT'],
settings_dict['ENGINE'],
settings_dict['NAME'],
self._test_database_user(),
)
@@ -15,7 +15,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
select_for_update_of_column = True
can_return_columns_from_insert = True
supports_subqueries_in_group_by = False
ignores_unnecessary_order_by_in_subqueries = False
supports_transactions = True
supports_timezones = False
has_native_duration_field = True
@@ -32,8 +31,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
requires_literal_defaults = True
closed_cursor_error_class = InterfaceError
bare_select_suffix = " FROM DUAL"
# select for update with limit can be achieved on Oracle, but not with the
# current backend.
# select for update with limit can be achieved on Oracle, but not with the current backend.
supports_select_for_update_with_limit = False
supports_temporal_subtraction = True
# Oracle doesn't ignore quoted identifiers case but the current backend
@@ -68,46 +66,44 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_json_field_contains = False
supports_collation_on_textfield = False
test_collations = {
"ci": "BINARY_CI",
"cs": "BINARY",
"non_default": "SWEDISH_CI",
"swedish_ci": "SWEDISH_CI",
'ci': 'BINARY_CI',
'cs': 'BINARY',
'non_default': 'SWEDISH_CI',
'swedish_ci': 'SWEDISH_CI',
}
test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'"
django_test_skips = {
"Oracle doesn't support SHA224.": {
"db_functions.text.test_sha224.SHA224Tests.test_basic",
"db_functions.text.test_sha224.SHA224Tests.test_transform",
'db_functions.text.test_sha224.SHA224Tests.test_basic',
'db_functions.text.test_sha224.SHA224Tests.test_transform',
},
"Oracle doesn't support bitwise XOR.": {
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor",
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null",
'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor',
'expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_null',
},
"Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": {
"expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering",
'expressions_window.tests.WindowFunctionTests.test_row_number_no_ordering',
},
"Raises ORA-00600: internal error code.": {
"model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery",
'Raises ORA-00600: internal error code on Oracle 18.': {
'model_fields.test_jsonfield.TestQuerying.test_usage_in_subquery',
},
}
django_test_expected_failures = {
# A bug in Django/cx_Oracle with respect to string handling (#23843).
"annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions",
"annotations.tests.NonAggregateAnnotationTestCase."
"test_custom_functions_can_ref_other_functions",
'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions',
'annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions_can_ref_other_functions',
}
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"GenericIPAddressField": "CharField",
"PositiveBigIntegerField": "BigIntegerField",
"PositiveIntegerField": "IntegerField",
"PositiveSmallIntegerField": "IntegerField",
"SmallIntegerField": "IntegerField",
"TimeField": "DateTimeField",
'GenericIPAddressField': 'CharField',
'PositiveBigIntegerField': 'BigIntegerField',
'PositiveIntegerField': 'IntegerField',
'PositiveSmallIntegerField': 'IntegerField',
'SmallIntegerField': 'IntegerField',
'TimeField': 'DateTimeField',
}
@cached_property
@@ -120,3 +116,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
raise
return True
@cached_property
def has_json_object_function(self):
# Oracle < 18 supports JSON_OBJECT() but it's not fully functional.
return self.connection.oracle_version >= (18,)
@@ -2,7 +2,7 @@ from django.db.models import DecimalField, DurationField, Func
class IntervalToSeconds(Func):
function = ""
function = ''
template = """
EXTRACT(day from %(expressions)s) * 86400 +
EXTRACT(hour from %(expressions)s) * 3600 +
@@ -11,16 +11,12 @@ class IntervalToSeconds(Func):
"""
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(
expression, output_field=output_field or DecimalField(), **extra
)
super().__init__(expression, output_field=output_field or DecimalField(), **extra)
class SecondsToInterval(Func):
function = "NUMTODSINTERVAL"
function = 'NUMTODSINTERVAL'
template = "%(function)s(%(expressions)s, 'SECOND')"
def __init__(self, expression, *, output_field=None, **extra):
super().__init__(
expression, output_field=output_field or DurationField(), **extra
)
super().__init__(expression, output_field=output_field or DurationField(), **extra)
@@ -3,12 +3,12 @@ from collections import namedtuple
import cx_Oracle
from django.db import models
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.utils.functional import cached_property
FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "is_json"))
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('is_autofield', 'is_json'))
class DatabaseIntrospection(BaseDatabaseIntrospection):
@@ -19,33 +19,33 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def data_types_reverse(self):
if self.connection.cx_oracle_version < (8,):
return {
cx_Oracle.BLOB: "BinaryField",
cx_Oracle.CLOB: "TextField",
cx_Oracle.DATETIME: "DateField",
cx_Oracle.FIXED_CHAR: "CharField",
cx_Oracle.FIXED_NCHAR: "CharField",
cx_Oracle.INTERVAL: "DurationField",
cx_Oracle.NATIVE_FLOAT: "FloatField",
cx_Oracle.NCHAR: "CharField",
cx_Oracle.NCLOB: "TextField",
cx_Oracle.NUMBER: "DecimalField",
cx_Oracle.STRING: "CharField",
cx_Oracle.TIMESTAMP: "DateTimeField",
cx_Oracle.BLOB: 'BinaryField',
cx_Oracle.CLOB: 'TextField',
cx_Oracle.DATETIME: 'DateField',
cx_Oracle.FIXED_CHAR: 'CharField',
cx_Oracle.FIXED_NCHAR: 'CharField',
cx_Oracle.INTERVAL: 'DurationField',
cx_Oracle.NATIVE_FLOAT: 'FloatField',
cx_Oracle.NCHAR: 'CharField',
cx_Oracle.NCLOB: 'TextField',
cx_Oracle.NUMBER: 'DecimalField',
cx_Oracle.STRING: 'CharField',
cx_Oracle.TIMESTAMP: 'DateTimeField',
}
else:
return {
cx_Oracle.DB_TYPE_DATE: "DateField",
cx_Oracle.DB_TYPE_BINARY_DOUBLE: "FloatField",
cx_Oracle.DB_TYPE_BLOB: "BinaryField",
cx_Oracle.DB_TYPE_CHAR: "CharField",
cx_Oracle.DB_TYPE_CLOB: "TextField",
cx_Oracle.DB_TYPE_INTERVAL_DS: "DurationField",
cx_Oracle.DB_TYPE_NCHAR: "CharField",
cx_Oracle.DB_TYPE_NCLOB: "TextField",
cx_Oracle.DB_TYPE_NVARCHAR: "CharField",
cx_Oracle.DB_TYPE_NUMBER: "DecimalField",
cx_Oracle.DB_TYPE_TIMESTAMP: "DateTimeField",
cx_Oracle.DB_TYPE_VARCHAR: "CharField",
cx_Oracle.DB_TYPE_DATE: 'DateField',
cx_Oracle.DB_TYPE_BINARY_DOUBLE: 'FloatField',
cx_Oracle.DB_TYPE_BLOB: 'BinaryField',
cx_Oracle.DB_TYPE_CHAR: 'CharField',
cx_Oracle.DB_TYPE_CLOB: 'TextField',
cx_Oracle.DB_TYPE_INTERVAL_DS: 'DurationField',
cx_Oracle.DB_TYPE_NCHAR: 'CharField',
cx_Oracle.DB_TYPE_NCLOB: 'TextField',
cx_Oracle.DB_TYPE_NVARCHAR: 'CharField',
cx_Oracle.DB_TYPE_NUMBER: 'DecimalField',
cx_Oracle.DB_TYPE_TIMESTAMP: 'DateTimeField',
cx_Oracle.DB_TYPE_VARCHAR: 'CharField',
}
def get_field_type(self, data_type, description):
@@ -53,30 +53,25 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
precision, scale = description[4:6]
if scale == 0:
if precision > 11:
return (
"BigAutoField"
if description.is_autofield
else "BigIntegerField"
)
return 'BigAutoField' if description.is_autofield else 'BigIntegerField'
elif 1 < precision < 6 and description.is_autofield:
return "SmallAutoField"
return 'SmallAutoField'
elif precision == 1:
return "BooleanField"
return 'BooleanField'
elif description.is_autofield:
return "AutoField"
return 'AutoField'
else:
return "IntegerField"
return 'IntegerField'
elif scale == -127:
return "FloatField"
return 'FloatField'
elif data_type == cx_Oracle.NCLOB and description.is_json:
return "JSONField"
return 'JSONField'
return super().get_field_type(data_type, description)
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute(
"""
cursor.execute("""
SELECT table_name, 't'
FROM user_tables
WHERE
@@ -89,12 +84,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
SELECT view_name, 'v' FROM user_views
UNION ALL
SELECT mview_name, 'v' FROM user_mviews
"""
)
return [
TableInfo(self.identifier_converter(row[0]), row[1])
for row in cursor.fetchall()
]
""")
return [TableInfo(self.identifier_converter(row[0]), row[1]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
"""
@@ -102,8 +93,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
interface.
"""
# user_tab_columns gives data default for columns
cursor.execute(
"""
cursor.execute("""
SELECT
user_tab_cols.column_name,
user_tab_cols.data_default,
@@ -136,51 +126,24 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
LEFT OUTER JOIN
user_tables ON user_tables.table_name = user_tab_cols.table_name
WHERE user_tab_cols.table_name = UPPER(%s)
""",
[table_name],
)
""", [table_name])
field_map = {
column: (
internal_size,
default if default != "NULL" else None,
collation,
is_autofield,
is_json,
)
for (
column,
default,
collation,
internal_size,
is_autofield,
is_json,
) in cursor.fetchall()
column: (internal_size, default if default != 'NULL' else None, collation, is_autofield, is_json)
for column, default, collation, internal_size, is_autofield, is_json in cursor.fetchall()
}
self.cache_bust_counter += 1
cursor.execute(
"SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
self.connection.ops.quote_name(table_name), self.cache_bust_counter
)
)
cursor.execute("SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
self.connection.ops.quote_name(table_name),
self.cache_bust_counter))
description = []
for desc in cursor.description:
name = desc[0]
internal_size, default, collation, is_autofield, is_json = field_map[name]
name = name % {} # cx_Oracle, for some reason, doubles percent signs.
description.append(
FieldInfo(
self.identifier_converter(name),
*desc[1:3],
internal_size,
desc[4] or 0,
desc[5] or 0,
*desc[6:],
default,
collation,
is_autofield,
is_json,
)
)
description.append(FieldInfo(
self.identifier_converter(name), *desc[1:3], internal_size, desc[4] or 0,
desc[5] or 0, *desc[6:], default, collation, is_autofield, is_json,
))
return description
def identifier_converter(self, name):
@@ -188,8 +151,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
return name.lower()
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute(
"""
cursor.execute("""
SELECT
user_tab_identity_cols.sequence_name,
user_tab_identity_cols.column_name
@@ -203,24 +165,20 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND cols.column_name = user_tab_identity_cols.column_name
AND user_constraints.constraint_type = 'P'
AND user_tab_identity_cols.table_name = UPPER(%s)
""",
[table_name],
)
""", [table_name])
# Oracle allows only one identity column per table.
row = cursor.fetchone()
if row:
return [
{
"name": self.identifier_converter(row[0]),
"table": self.identifier_converter(table_name),
"column": self.identifier_converter(row[1]),
}
]
return [{
'name': self.identifier_converter(row[0]),
'table': self.identifier_converter(table_name),
'column': self.identifier_converter(row[1]),
}]
# To keep backward compatibility for AutoFields that aren't Oracle
# identity columns.
for f in table_fields:
if isinstance(f, models.AutoField):
return [{"table": table_name, "column": f.column}]
return [{'table': table_name, 'column': f.column}]
return []
def get_relations(self, cursor, table_name):
@@ -229,48 +187,37 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
representing all relationships to the given table.
"""
table_name = table_name.upper()
cursor.execute(
"""
cursor.execute("""
SELECT ca.column_name, cb.table_name, cb.column_name
FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb
WHERE user_constraints.table_name = %s AND
user_constraints.constraint_name = ca.constraint_name AND
user_constraints.r_constraint_name = cb.constraint_name AND
ca.position = cb.position""",
[table_name],
)
ca.position = cb.position""", [table_name])
return {
self.identifier_converter(field_name): (
self.identifier_converter(rel_field_name),
self.identifier_converter(rel_table_name),
)
for field_name, rel_table_name, rel_field_name in cursor.fetchall()
) for field_name, rel_table_name, rel_field_name in cursor.fetchall()
}
def get_key_columns(self, cursor, table_name):
cursor.execute(
"""
SELECT
ccol.column_name,
rcol.table_name AS referenced_table,
rcol.column_name AS referenced_column
cursor.execute("""
SELECT ccol.column_name, rcol.table_name AS referenced_table, rcol.column_name AS referenced_column
FROM user_constraints c
JOIN user_cons_columns ccol
ON ccol.constraint_name = c.constraint_name
JOIN user_cons_columns rcol
ON rcol.constraint_name = c.r_constraint_name
WHERE c.table_name = %s AND c.constraint_type = 'R'""",
[table_name.upper()],
)
WHERE c.table_name = %s AND c.constraint_type = 'R'""", [table_name.upper()])
return [
tuple(self.identifier_converter(cell) for cell in row)
for row in cursor.fetchall()
]
def get_primary_key_column(self, cursor, table_name):
cursor.execute(
"""
cursor.execute("""
SELECT
cols.column_name
FROM
@@ -281,9 +228,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
user_constraints.constraint_type = 'P' AND
user_constraints.table_name = UPPER(%s) AND
cols.position = 1
""",
[table_name],
)
""", [table_name])
row = cursor.fetchone()
return self.identifier_converter(row[0]) if row else None
@@ -294,12 +239,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"""
constraints = {}
# Loop over the constraints, getting PKs, uniques, and checks
cursor.execute(
"""
cursor.execute("""
SELECT
user_constraints.constraint_name,
LISTAGG(LOWER(cols.column_name), ',')
WITHIN GROUP (ORDER BY cols.position),
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
CASE user_constraints.constraint_type
WHEN 'P' THEN 1
ELSE 0
@@ -315,68 +258,56 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
FROM
user_constraints
LEFT OUTER JOIN
user_cons_columns cols
ON user_constraints.constraint_name = cols.constraint_name
user_cons_columns cols ON user_constraints.constraint_name = cols.constraint_name
WHERE
user_constraints.constraint_type = ANY('P', 'U', 'C')
AND user_constraints.table_name = UPPER(%s)
GROUP BY user_constraints.constraint_name, user_constraints.constraint_type
""",
[table_name],
)
""", [table_name])
for constraint, columns, pk, unique, check in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
"columns": columns.split(","),
"primary_key": pk,
"unique": unique,
"foreign_key": None,
"check": check,
"index": unique, # All uniques come with an index
'columns': columns.split(','),
'primary_key': pk,
'unique': unique,
'foreign_key': None,
'check': check,
'index': unique, # All uniques come with an index
}
# Foreign key constraints
cursor.execute(
"""
cursor.execute("""
SELECT
cons.constraint_name,
LISTAGG(LOWER(cols.column_name), ',')
WITHIN GROUP (ORDER BY cols.position),
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.position),
LOWER(rcols.table_name),
LOWER(rcols.column_name)
FROM
user_constraints cons
INNER JOIN
user_cons_columns rcols
ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1
user_cons_columns rcols ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1
LEFT OUTER JOIN
user_cons_columns cols
ON cons.constraint_name = cols.constraint_name
user_cons_columns cols ON cons.constraint_name = cols.constraint_name
WHERE
cons.constraint_type = 'R' AND
cons.table_name = UPPER(%s)
GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name
""",
[table_name],
)
""", [table_name])
for constraint, columns, other_table, other_column in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
"primary_key": False,
"unique": False,
"foreign_key": (other_table, other_column),
"check": False,
"index": False,
"columns": columns.split(","),
'primary_key': False,
'unique': False,
'foreign_key': (other_table, other_column),
'check': False,
'index': False,
'columns': columns.split(','),
}
# Now get indexes
cursor.execute(
"""
cursor.execute("""
SELECT
ind.index_name,
LOWER(ind.index_type),
LOWER(ind.uniqueness),
LISTAGG(LOWER(cols.column_name), ',')
WITHIN GROUP (ORDER BY cols.column_position),
LISTAGG(LOWER(cols.column_name), ',') WITHIN GROUP (ORDER BY cols.column_position),
LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position)
FROM
user_ind_columns cols, user_indexes ind
@@ -387,20 +318,18 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
FROM user_constraints cons
WHERE ind.index_name = cons.index_name
) AND cols.index_name = ind.index_name
GROUP BY ind.index_name, ind.index_type, ind.uniqueness
""",
[table_name],
)
for constraint, type_, unique, columns, orders in cursor.fetchall():
GROUP BY ind.index_name, ind.index_type
""", [table_name])
for constraint, type_, columns, orders in cursor.fetchall():
constraint = self.identifier_converter(constraint)
constraints[constraint] = {
"primary_key": False,
"unique": unique == "unique",
"foreign_key": None,
"check": False,
"index": True,
"type": "idx" if type_ == "normal" else type_,
"columns": columns.split(","),
"orders": orders.split(","),
'primary_key': False,
'unique': False,
'foreign_key': None,
'check': False,
'index': True,
'type': 'idx' if type_ == 'normal' else type_,
'columns': columns.split(','),
'orders': orders.split(','),
}
return constraints
@@ -5,8 +5,8 @@ from functools import lru_cache
from django.conf import settings
from django.db import DatabaseError, NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
from django.db.backends.utils import strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper
from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode
from django.utils import timezone
@@ -23,17 +23,17 @@ class DatabaseOperations(BaseDatabaseOperations):
# SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
# SmallAutoField, to preserve backward compatibility.
integer_field_ranges = {
"SmallIntegerField": (-99999999999, 99999999999),
"IntegerField": (-99999999999, 99999999999),
"BigIntegerField": (-9999999999999999999, 9999999999999999999),
"PositiveBigIntegerField": (0, 9999999999999999999),
"PositiveSmallIntegerField": (0, 99999999999),
"PositiveIntegerField": (0, 99999999999),
"SmallAutoField": (-99999, 99999),
"AutoField": (-99999999999, 99999999999),
"BigAutoField": (-9999999999999999999, 9999999999999999999),
'SmallIntegerField': (-99999999999, 99999999999),
'IntegerField': (-99999999999, 99999999999),
'BigIntegerField': (-9999999999999999999, 9999999999999999999),
'PositiveBigIntegerField': (0, 9999999999999999999),
'PositiveSmallIntegerField': (0, 99999999999),
'PositiveIntegerField': (0, 99999999999),
'SmallAutoField': (-99999, 99999),
'AutoField': (-99999999999, 99999999999),
'BigAutoField': (-9999999999999999999, 9999999999999999999),
}
set_operators = {**BaseDatabaseOperations.set_operators, "difference": "MINUS"}
set_operators = {**BaseDatabaseOperations.set_operators, 'difference': 'MINUS'}
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
_sequence_reset_sql = """
@@ -61,45 +61,42 @@ END;
/"""
# Oracle doesn't support string without precision; use the max string size.
cast_char_field_without_max_length = "NVARCHAR2(2000)"
cast_char_field_without_max_length = 'NVARCHAR2(2000)'
cast_data_types = {
"AutoField": "NUMBER(11)",
"BigAutoField": "NUMBER(19)",
"SmallAutoField": "NUMBER(5)",
"TextField": cast_char_field_without_max_length,
'AutoField': 'NUMBER(11)',
'BigAutoField': 'NUMBER(19)',
'SmallAutoField': 'NUMBER(5)',
'TextField': cast_char_field_without_max_length,
}
def cache_key_culling_sql(self):
return (
"SELECT cache_key FROM %s "
"ORDER BY cache_key OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY"
)
return 'SELECT cache_key FROM %s ORDER BY cache_key OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY'
def date_extract_sql(self, lookup_type, field_name):
if lookup_type == "week_day":
if lookup_type == 'week_day':
# TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
return "TO_CHAR(%s, 'D')" % field_name
elif lookup_type == "iso_week_day":
elif lookup_type == 'iso_week_day':
return "TO_CHAR(%s - 1, 'D')" % field_name
elif lookup_type == "week":
elif lookup_type == 'week':
# IW = ISO week number
return "TO_CHAR(%s, 'IW')" % field_name
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
return "TO_CHAR(%s, 'Q')" % field_name
elif lookup_type == "iso_year":
elif lookup_type == 'iso_year':
return "TO_CHAR(%s, 'IYYY')" % field_name
else:
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/EXTRACT-datetime.html
return "EXTRACT(%s FROM %s)" % (lookup_type.upper(), field_name)
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ("year", "month"):
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'):
return "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
return "TRUNC(%s, 'Q')" % field_name
elif lookup_type == "week":
elif lookup_type == 'week':
return "TRUNC(%s, 'IW')" % field_name
else:
return "TRUNC(%s)" % field_name
@@ -108,11 +105,14 @@ END;
# if the time zone name is passed in parameter. Use interpolation instead.
# https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ
# This regexp matches all time zone names from the zoneinfo database.
_tzname_re = _lazy_re_compile(r"^[\w/:+-]+$")
_tzname_re = _lazy_re_compile(r'^[\w/:+-]+$')
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
return f"{sign}{offset}" if offset else tzname
if '+' in tzname:
return tzname[tzname.find('+'):]
elif '-' in tzname:
return tzname[tzname.find('-'):]
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if not (settings.USE_TZ and tzname):
@@ -132,19 +132,12 @@ END;
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "TRUNC(%s)" % field_name
return 'TRUNC(%s)' % field_name
def datetime_cast_time_sql(self, field_name, tzname):
# Since `TimeField` values are stored as TIMESTAMP change to the
# default date and convert the field to the specified timezone.
convert_datetime_sql = (
"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR(%s, 'HH24:MI:SS.FF')), "
"'YYYY-MM-DD HH24:MI:SS.FF')"
) % self._convert_field_to_tz(field_name, tzname)
return "CASE WHEN %s IS NOT NULL THEN %s ELSE NULL END" % (
field_name,
convert_datetime_sql,
)
# Since `TimeField` values are stored as TIMESTAMP where only the date
# part is ignored, convert the field to the specified timezone.
return self._convert_field_to_tz(field_name, tzname)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
@@ -152,23 +145,21 @@ END;
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ("year", "month"):
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf/ROUND-and-TRUNC-Date-Functions.html
if lookup_type in ('year', 'month'):
sql = "TRUNC(%s, '%s')" % (field_name, lookup_type.upper())
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
sql = "TRUNC(%s, 'Q')" % field_name
elif lookup_type == "week":
elif lookup_type == 'week':
sql = "TRUNC(%s, 'IW')" % field_name
elif lookup_type == "day":
elif lookup_type == 'day':
sql = "TRUNC(%s)" % field_name
elif lookup_type == "hour":
elif lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == "minute":
elif lookup_type == 'minute':
sql = "TRUNC(%s, 'MI')" % field_name
else:
sql = (
"CAST(%s AS DATE)" % field_name
) # Cast to DATE removes sub-second precision.
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
@@ -176,42 +167,40 @@ END;
# `DateTimeField` and `TimeField` are stored as TIMESTAMP where
# the date part of the later is ignored.
field_name = self._convert_field_to_tz(field_name, tzname)
if lookup_type == "hour":
if lookup_type == 'hour':
sql = "TRUNC(%s, 'HH24')" % field_name
elif lookup_type == "minute":
elif lookup_type == 'minute':
sql = "TRUNC(%s, 'MI')" % field_name
elif lookup_type == "second":
sql = (
"CAST(%s AS DATE)" % field_name
) # Cast to DATE removes sub-second precision.
elif lookup_type == 'second':
sql = "CAST(%s AS DATE)" % field_name # Cast to DATE removes sub-second precision.
return sql
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type in ["JSONField", "TextField"]:
if internal_type in ['JSONField', 'TextField']:
converters.append(self.convert_textfield_value)
elif internal_type == "BinaryField":
elif internal_type == 'BinaryField':
converters.append(self.convert_binaryfield_value)
elif internal_type == "BooleanField":
elif internal_type in ['BooleanField', 'NullBooleanField']:
converters.append(self.convert_booleanfield_value)
elif internal_type == "DateTimeField":
elif internal_type == 'DateTimeField':
if settings.USE_TZ:
converters.append(self.convert_datetimefield_value)
elif internal_type == "DateField":
elif internal_type == 'DateField':
converters.append(self.convert_datefield_value)
elif internal_type == "TimeField":
elif internal_type == 'TimeField':
converters.append(self.convert_timefield_value)
elif internal_type == "UUIDField":
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
# Oracle stores empty strings as null. If the field accepts the empty
# string, undo this to adhere to the Django convention of using
# the empty string instead of null.
if expression.output_field.empty_strings_allowed:
if expression.field.empty_strings_allowed:
converters.append(
self.convert_empty_bytes
if internal_type == "BinaryField"
else self.convert_empty_string
if internal_type == 'BinaryField' else
self.convert_empty_string
)
return converters
@@ -256,11 +245,11 @@ END;
@staticmethod
def convert_empty_string(value, expression, connection):
return "" if value is None else value
return '' if value is None else value
@staticmethod
def convert_empty_bytes(value, expression, connection):
return b"" if value is None else value
return b'' if value is None else value
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
@@ -269,18 +258,20 @@ END;
columns = []
for param in returning_params:
value = param.get_value()
if value == []:
if value is None or value == []:
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
raise DatabaseError(
"The database did not return a new row id. Probably "
'The database did not return a new row id. Probably '
'"ORA-1403: no data found" was raised internally but was '
"hidden by the Oracle OCI library (see "
"https://code.djangoproject.com/ticket/28859)."
'hidden by the Oracle OCI library (see '
'https://code.djangoproject.com/ticket/28859).'
)
columns.append(value[0])
# cx_Oracle < 7 returns value, >= 7 returns list with single value.
columns.append(value[0] if isinstance(value, list) else value)
return tuple(columns)
def field_cast_sql(self, db_type, internal_type):
if db_type and db_type.endswith("LOB") and internal_type != "JSONField":
if db_type and db_type.endswith('LOB') and internal_type != 'JSONField':
return "DBMS_LOB.SUBSTR(%s)"
else:
return "%s"
@@ -290,14 +281,10 @@ END;
def limit_offset_sql(self, low_mark, high_mark):
fetch, offset = self._get_limit_offset_params(low_mark, high_mark)
return " ".join(
sql
for sql in (
("OFFSET %d ROWS" % offset) if offset else None,
("FETCH FIRST %d ROWS ONLY" % fetch) if fetch else None,
)
if sql
)
return ' '.join(sql for sql in (
('OFFSET %d ROWS' % offset) if offset else None,
('FETCH FIRST %d ROWS ONLY' % fetch) if fetch else None,
) if sql)
def last_executed_query(self, cursor, sql, params):
# https://cx-oracle.readthedocs.io/en/latest/cursor.html#Cursor.statement
@@ -308,14 +295,10 @@ END;
# parameters manually.
if isinstance(params, (tuple, list)):
for i, param in enumerate(params):
statement = statement.replace(
":arg%d" % i, force_str(param, errors="replace")
)
statement = statement.replace(':arg%d' % i, force_str(param, errors='replace'))
elif isinstance(params, dict):
for key, param in params.items():
statement = statement.replace(
":%s" % key, force_str(param, errors="replace")
)
statement = statement.replace(':%s' % key, force_str(param, errors='replace'))
return statement
def last_insert_id(self, cursor, table_name, pk_name):
@@ -324,10 +307,10 @@ END;
return cursor.fetchone()[0]
def lookup_cast(self, lookup_type, internal_type=None):
if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
return "UPPER(%s)"
if internal_type == "JSONField" and lookup_type == "exact":
return "DBMS_LOB.SUBSTR(%s)"
if internal_type == 'JSONField' and lookup_type == 'exact':
return 'DBMS_LOB.SUBSTR(%s)'
return "%s"
def max_in_list_size(self):
@@ -344,7 +327,7 @@ END;
def process_clob(self, value):
if value is None:
return ""
return ''
return value.read()
def quote_name(self, name):
@@ -353,69 +336,59 @@ END;
# always defaults to uppercase.
# We simplify things by making Oracle identifiers always uppercase.
if not name.startswith('"') and not name.endswith('"'):
name = '"%s"' % truncate_name(name, self.max_name_length())
name = '"%s"' % truncate_name(name.upper(), self.max_name_length())
# Oracle puts the query text into a (query % args) construct, so % signs
# in names need to be escaped. The '%%' will be collapsed back to '%' at
# that stage so we aren't really making the name longer here.
name = name.replace("%", "%%")
name = name.replace('%', '%%')
return name.upper()
def regex_lookup(self, lookup_type):
if lookup_type == "regex":
if lookup_type == 'regex':
match_option = "'c'"
else:
match_option = "'i'"
return "REGEXP_LIKE(%%s, %%s, %s)" % match_option
return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
def return_insert_columns(self, fields):
if not fields:
return "", ()
return '', ()
field_names = []
params = []
for field in fields:
field_names.append(
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
)
field_names.append('%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
))
params.append(InsertVar(field))
return "RETURNING %s INTO %s" % (
", ".join(field_names),
", ".join(["%s"] * len(params)),
return 'RETURNING %s INTO %s' % (
', '.join(field_names),
', '.join(['%s'] * len(params)),
), tuple(params)
def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor:
if recursive:
cursor.execute(
"""
cursor.execute("""
SELECT
user_tables.table_name, rcons.constraint_name
FROM
user_tables
JOIN
user_constraints cons
ON (user_tables.table_name = cons.table_name
AND cons.constraint_type = ANY('P', 'U'))
ON (user_tables.table_name = cons.table_name AND cons.constraint_type = ANY('P', 'U'))
LEFT JOIN
user_constraints rcons
ON (user_tables.table_name = rcons.table_name
AND rcons.constraint_type = 'R')
ON (user_tables.table_name = rcons.table_name AND rcons.constraint_type = 'R')
START WITH user_tables.table_name = UPPER(%s)
CONNECT BY
NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name
CONNECT BY NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name
GROUP BY
user_tables.table_name, rcons.constraint_name
HAVING user_tables.table_name != UPPER(%s)
ORDER BY MAX(level) DESC
""",
(table_name, table_name),
)
""", (table_name, table_name))
else:
cursor.execute(
"""
cursor.execute("""
SELECT
cons.table_name, cons.constraint_name
FROM
@@ -423,9 +396,7 @@ END;
WHERE
cons.constraint_type = 'R'
AND cons.table_name = UPPER(%s)
""",
(table_name,),
)
""", (table_name,))
return cursor.fetchall()
@cached_property
@@ -445,54 +416,42 @@ END;
# which truncates all dependent tables by manually retrieving all
# foreign key constraints and resolving dependencies.
for table in tables:
for foreign_table, constraint in self._foreign_key_constraints(
table, recursive=allow_cascade
):
for foreign_table, constraint in self._foreign_key_constraints(table, recursive=allow_cascade):
if allow_cascade:
truncated_tables.add(foreign_table)
constraints.add((foreign_table, constraint))
sql = (
[
"%s %s %s %s %s %s %s %s;"
% (
style.SQL_KEYWORD("ALTER"),
style.SQL_KEYWORD("TABLE"),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD("DISABLE"),
style.SQL_KEYWORD("CONSTRAINT"),
style.SQL_FIELD(self.quote_name(constraint)),
style.SQL_KEYWORD("KEEP"),
style.SQL_KEYWORD("INDEX"),
)
for table, constraint in constraints
]
+ [
"%s %s %s;"
% (
style.SQL_KEYWORD("TRUNCATE"),
style.SQL_KEYWORD("TABLE"),
style.SQL_FIELD(self.quote_name(table)),
)
for table in truncated_tables
]
+ [
"%s %s %s %s %s %s;"
% (
style.SQL_KEYWORD("ALTER"),
style.SQL_KEYWORD("TABLE"),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD("ENABLE"),
style.SQL_KEYWORD("CONSTRAINT"),
style.SQL_FIELD(self.quote_name(constraint)),
)
for table, constraint in constraints
]
)
sql = [
'%s %s %s %s %s %s %s %s;' % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD('DISABLE'),
style.SQL_KEYWORD('CONSTRAINT'),
style.SQL_FIELD(self.quote_name(constraint)),
style.SQL_KEYWORD('KEEP'),
style.SQL_KEYWORD('INDEX'),
) for table, constraint in constraints
] + [
'%s %s %s;' % (
style.SQL_KEYWORD('TRUNCATE'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
) for table in truncated_tables
] + [
'%s %s %s %s %s %s;' % (
style.SQL_KEYWORD('ALTER'),
style.SQL_KEYWORD('TABLE'),
style.SQL_FIELD(self.quote_name(table)),
style.SQL_KEYWORD('ENABLE'),
style.SQL_KEYWORD('CONSTRAINT'),
style.SQL_FIELD(self.quote_name(constraint)),
) for table, constraint in constraints
]
if reset_sequences:
sequences = [
sequence
for sequence in self.connection.introspection.sequence_list()
if sequence["table"].upper() in truncated_tables
if sequence['table'].upper() in truncated_tables
]
# Since we've just deleted all the rows, running our sequence ALTER
# code will reset the sequence to 0.
@@ -502,17 +461,15 @@ END;
def sequence_reset_by_name_sql(self, style, sequences):
sql = []
for sequence_info in sequences:
no_autofield_sequence_name = self._get_no_autofield_sequence_name(
sequence_info["table"]
)
table = self.quote_name(sequence_info["table"])
column = self.quote_name(sequence_info["column"] or "id")
no_autofield_sequence_name = self._get_no_autofield_sequence_name(sequence_info['table'])
table = self.quote_name(sequence_info['table'])
column = self.quote_name(sequence_info['column'] or 'id')
query = self._sequence_reset_sql % {
"no_autofield_sequence_name": no_autofield_sequence_name,
"table": table,
"column": column,
"table_name": strip_quotes(table),
"column_name": strip_quotes(column),
'no_autofield_sequence_name': no_autofield_sequence_name,
'table': table,
'column': column,
'table_name': strip_quotes(table),
'column_name': strip_quotes(column),
}
sql.append(query)
return sql
@@ -523,28 +480,23 @@ END;
for model in model_list:
for f in model._meta.local_fields:
if isinstance(f, AutoField):
no_autofield_sequence_name = self._get_no_autofield_sequence_name(
model._meta.db_table
)
no_autofield_sequence_name = self._get_no_autofield_sequence_name(model._meta.db_table)
table = self.quote_name(model._meta.db_table)
column = self.quote_name(f.column)
output.append(
query
% {
"no_autofield_sequence_name": no_autofield_sequence_name,
"table": table,
"column": column,
"table_name": strip_quotes(table),
"column_name": strip_quotes(column),
}
)
output.append(query % {
'no_autofield_sequence_name': no_autofield_sequence_name,
'table': table,
'column': column,
'table_name': strip_quotes(table),
'column_name': strip_quotes(column),
})
# Only one AutoField is allowed per model, so don't
# continue to loop
break
return output
def start_transaction_sql(self):
return ""
return ''
def tablespace_sql(self, tablespace, inline=False):
if inline:
@@ -575,7 +527,7 @@ END;
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
# cx_Oracle doesn't support tz-aware datetimes
@@ -583,10 +535,7 @@ END;
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"Oracle backend does not support timezone-aware datetimes when "
"USE_TZ is False."
)
raise ValueError("Oracle backend does not support timezone-aware datetimes when USE_TZ is False.")
return Oracle_datetime.from_datetime(value)
@@ -595,39 +544,38 @@ END;
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
if isinstance(value, str):
return datetime.datetime.strptime(value, "%H:%M:%S")
return datetime.datetime.strptime(value, '%H:%M:%S')
# Oracle doesn't support tz-aware times
if timezone.is_aware(value):
raise ValueError("Oracle backend does not support timezone-aware times.")
return Oracle_datetime(
1900, 1, 1, value.hour, value.minute, value.second, value.microsecond
)
return Oracle_datetime(1900, 1, 1, value.hour, value.minute,
value.second, value.microsecond)
def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
return value
def combine_expression(self, connector, sub_expressions):
lhs, rhs = sub_expressions
if connector == "%%":
return "MOD(%s)" % ",".join(sub_expressions)
elif connector == "&":
return "BITAND(%s)" % ",".join(sub_expressions)
elif connector == "|":
return "BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s" % {"lhs": lhs, "rhs": rhs}
elif connector == "<<":
return "(%(lhs)s * POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
elif connector == ">>":
return "FLOOR(%(lhs)s / POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
elif connector == "^":
return "POWER(%s)" % ",".join(sub_expressions)
elif connector == "#":
raise NotSupportedError("Bitwise XOR is not supported in Oracle.")
if connector == '%%':
return 'MOD(%s)' % ','.join(sub_expressions)
elif connector == '&':
return 'BITAND(%s)' % ','.join(sub_expressions)
elif connector == '|':
return 'BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s' % {'lhs': lhs, 'rhs': rhs}
elif connector == '<<':
return '(%(lhs)s * POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
elif connector == '>>':
return 'FLOOR(%(lhs)s / POWER(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
elif connector == '^':
return 'POWER(%s)' % ','.join(sub_expressions)
elif connector == '#':
raise NotSupportedError('Bitwise XOR is not supported in Oracle.')
return super().combine_expression(connector, sub_expressions)
def _get_no_autofield_sequence_name(self, table):
@@ -636,17 +584,14 @@ END;
AutoFields that aren't Oracle identity columns.
"""
name_length = self.max_name_length() - 3
return "%s_SQ" % truncate_name(strip_quotes(table), name_length).upper()
return '%s_SQ' % truncate_name(strip_quotes(table), name_length).upper()
def _get_sequence_name(self, cursor, table, pk_name):
cursor.execute(
"""
cursor.execute("""
SELECT sequence_name
FROM user_tab_identity_cols
WHERE table_name = UPPER(%s)
AND column_name = UPPER(%s)""",
[table, pk_name],
)
AND column_name = UPPER(%s)""", [table, pk_name])
row = cursor.fetchone()
return self._get_no_autofield_sequence_name(table) if row is None else row[0]
@@ -657,33 +602,26 @@ END;
for i, placeholder in enumerate(row):
# A model without any fields has fields=[None].
if fields[i]:
internal_type = getattr(
fields[i], "target_field", fields[i]
).get_internal_type()
placeholder = (
BulkInsertMapper.types.get(internal_type, "%s") % placeholder
)
internal_type = getattr(fields[i], 'target_field', fields[i]).get_internal_type()
placeholder = BulkInsertMapper.types.get(internal_type, '%s') % placeholder
# Add columns aliases to the first select to avoid "ORA-00918:
# column ambiguously defined" when two or more columns in the
# first select have the same value.
if not query:
placeholder = "%s col_%s" % (placeholder, i)
placeholder = '%s col_%s' % (placeholder, i)
select.append(placeholder)
query.append("SELECT %s FROM DUAL" % ", ".join(select))
query.append('SELECT %s FROM DUAL' % ', '.join(select))
# Bulk insert to tables with Oracle identity columns causes Oracle to
# add sequence.nextval to it. Sequence.nextval cannot be used with the
# UNION operator. To prevent incorrect SQL, move UNION to a subquery.
return "SELECT * FROM (%s)" % " UNION ALL ".join(query)
return 'SELECT * FROM (%s)' % ' UNION ALL '.join(query)
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == "DateField":
if internal_type == 'DateField':
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
return (
"NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql),
params,
)
return "NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql), params
return super().subtract_temporals(internal_type, lhs, rhs)
def bulk_batch_size(self, fields, objs):
@@ -697,12 +635,10 @@ END;
Oracle supports only EXISTS(...) or filters in the WHERE clause, others
must be compared with True.
"""
if isinstance(expression, (Exists, Lookup, WhereNode)):
if isinstance(expression, (Exists, WhereNode)):
return True
if isinstance(expression, ExpressionWrapper) and expression.conditional:
return self.conditional_expression_supported_in_where_clause(
expression.expression
)
return self.conditional_expression_supported_in_where_clause(expression.expression)
if isinstance(expression, RawSQL) and expression.conditional:
return True
return False
@@ -3,10 +3,7 @@ import datetime
import re
from django.db import DatabaseError
from django.db.backends.base.schema import (
BaseDatabaseSchemaEditor,
_related_non_m2m_objects,
)
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
@@ -21,9 +18,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_alter_column_collate = "MODIFY %(column)s %(type)s%(collation)s"
sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
sql_create_column_inline_fk = (
"CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
)
sql_create_column_inline_fk = 'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS"
sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"
@@ -31,7 +26,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
return "'%s'" % value
elif isinstance(value, str):
return "'%s'" % value.replace("'", "''").replace("%", "%%")
return "'%s'" % value.replace("\'", "\'\'").replace('%', '%%')
elif isinstance(value, (bytes, bytearray, memoryview)):
return "'%s'" % value.hex()
elif isinstance(value, bool):
@@ -50,8 +45,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Run superclass action
super().delete_model(model)
# Clean up manually created sequence.
self.execute(
"""
self.execute("""
DECLARE
i INTEGER;
BEGIN
@@ -61,13 +55,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
END IF;
END;
/"""
% {
"sq_name": self.connection.ops._get_no_autofield_sequence_name(
model._meta.db_table
)
}
)
/""" % {'sq_name': self.connection.ops._get_no_autofield_sequence_name(model._meta.db_table)})
def alter_field(self, model, old_field, new_field, strict=False):
try:
@@ -76,16 +64,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
description = str(e)
# If we're changing type to an unsupported type we need a
# SQLite-ish workaround
if "ORA-22858" in description or "ORA-22859" in description:
if 'ORA-22858' in description or 'ORA-22859' in description:
self._alter_field_type_workaround(model, old_field, new_field)
# If an identity column is changing to a non-numeric type, drop the
# identity first.
elif "ORA-30675" in description:
elif 'ORA-30675' in description:
self._drop_identity(model._meta.db_table, old_field.column)
self.alter_field(model, old_field, new_field, strict)
# If a primary key column is changing to an identity column, drop
# the primary key first.
elif "ORA-30673" in description and old_field.primary_key:
elif 'ORA-30673' in description and old_field.primary_key:
self._delete_primary_key(model, strict=True)
self._alter_field_type_workaround(model, old_field, new_field)
else:
@@ -105,66 +93,45 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Make a new field that's like the new one but with a temporary
# column name.
new_temp_field = copy.deepcopy(new_field)
new_temp_field.null = new_field.get_internal_type() not in (
"AutoField",
"BigAutoField",
"SmallAutoField",
)
new_temp_field.null = (new_field.get_internal_type() not in ('AutoField', 'BigAutoField', 'SmallAutoField'))
new_temp_field.column = self._generate_temp_name(new_field.column)
# Add it
self.add_field(model, new_temp_field)
# Explicit data type conversion
# https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf
# https://docs.oracle.com/en/database/oracle/oracle-database/18/sqlrf
# /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD
new_value = self.quote_name(old_field.column)
old_type = old_field.db_type(self.connection)
if re.match("^N?CLOB", old_type):
if re.match('^N?CLOB', old_type):
new_value = "TO_CHAR(%s)" % new_value
old_type = "VARCHAR2"
if re.match("^N?VARCHAR2", old_type):
old_type = 'VARCHAR2'
if re.match('^N?VARCHAR2', old_type):
new_internal_type = new_field.get_internal_type()
if new_internal_type == "DateField":
if new_internal_type == 'DateField':
new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value
elif new_internal_type == "DateTimeField":
elif new_internal_type == 'DateTimeField':
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
elif new_internal_type == "TimeField":
elif new_internal_type == 'TimeField':
# TimeField are stored as TIMESTAMP with a 1900-01-01 date part.
new_value = "CONCAT('1900-01-01 ', %s)" % new_value
new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
new_value = "TO_TIMESTAMP(CONCAT('1900-01-01 ', %s), 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
# Transfer values across
self.execute(
"UPDATE %s set %s=%s"
% (
self.quote_name(model._meta.db_table),
self.quote_name(new_temp_field.column),
new_value,
)
)
self.execute("UPDATE %s set %s=%s" % (
self.quote_name(model._meta.db_table),
self.quote_name(new_temp_field.column),
new_value,
))
# Drop the old field
self.remove_field(model, old_field)
# Rename and possibly make the new field NOT NULL
super().alter_field(model, new_temp_field, new_field)
# Recreate foreign key (if necessary) because the old field is not
# passed to the alter_field() and data types of new_temp_field and
# new_field always match.
new_type = new_field.db_type(self.connection)
if (
(old_field.primary_key and new_field.primary_key)
or (old_field.unique and new_field.unique)
) and old_type != new_type:
for _, rel in _related_non_m2m_objects(new_temp_field, new_field):
if rel.field.db_constraint:
self.execute(
self._create_fk_sql(rel.related_model, rel.field, "_fk")
)
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
auto_field_types = {"AutoField", "BigAutoField", "SmallAutoField"}
auto_field_types = {'AutoField', 'BigAutoField', 'SmallAutoField'}
# Drop the identity if migrating away from AutoField.
if (
old_field.get_internal_type() in auto_field_types
and new_field.get_internal_type() not in auto_field_types
and self._is_identity_column(model._meta.db_table, new_field.column)
old_field.get_internal_type() in auto_field_types and
new_field.get_internal_type() not in auto_field_types and
self._is_identity_column(model._meta.db_table, new_field.column)
):
self._drop_identity(model._meta.db_table, new_field.column)
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
@@ -190,55 +157,42 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _field_should_be_indexed(self, model, field):
create_index = super()._field_should_be_indexed(model, field)
db_type = field.db_type(self.connection)
if (
db_type is not None
and db_type.lower() in self.connection._limited_data_types
):
if db_type is not None and db_type.lower() in self.connection._limited_data_types:
return False
return create_index
def _unique_should_be_added(self, old_field, new_field):
return super()._unique_should_be_added(
old_field, new_field
) and not self._field_became_primary_key(old_field, new_field)
return (
super()._unique_should_be_added(old_field, new_field) and
not self._field_became_primary_key(old_field, new_field)
)
def _is_identity_column(self, table_name, column_name):
with self.connection.cursor() as cursor:
cursor.execute(
"""
cursor.execute("""
SELECT
CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 END
FROM user_tab_cols
WHERE table_name = %s AND
column_name = %s
""",
[self.normalize_name(table_name), self.normalize_name(column_name)],
)
""", [self.normalize_name(table_name), self.normalize_name(column_name)])
row = cursor.fetchone()
return row[0] if row else False
def _drop_identity(self, table_name, column_name):
self.execute(
"ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY"
% {
"table": self.quote_name(table_name),
"column": self.quote_name(column_name),
}
)
self.execute('ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY' % {
'table': self.quote_name(table_name),
'column': self.quote_name(column_name),
})
def _get_default_collation(self, table_name):
with self.connection.cursor() as cursor:
cursor.execute(
"""
cursor.execute("""
SELECT default_collation FROM user_tables WHERE table_name = %s
""",
[self.normalize_name(table_name)],
)
""", [self.normalize_name(table_name)])
return cursor.fetchone()[0]
def _alter_column_collation_sql(self, model, new_field, new_type, new_collation):
if new_collation is None:
new_collation = self._get_default_collation(model._meta.db_table)
return super()._alter_column_collation_sql(
model, new_field, new_type, new_collation
)
return super()._alter_column_collation_sql(model, new_field, new_type, new_collation)
@@ -9,25 +9,24 @@ class InsertVar:
as a parameter, in order to receive the id of the row created by an
insert statement.
"""
types = {
"AutoField": int,
"BigAutoField": int,
"SmallAutoField": int,
"IntegerField": int,
"BigIntegerField": int,
"SmallIntegerField": int,
"PositiveBigIntegerField": int,
"PositiveSmallIntegerField": int,
"PositiveIntegerField": int,
"FloatField": Database.NATIVE_FLOAT,
"DateTimeField": Database.TIMESTAMP,
"DateField": Database.Date,
"DecimalField": Database.NUMBER,
'AutoField': int,
'BigAutoField': int,
'SmallAutoField': int,
'IntegerField': int,
'BigIntegerField': int,
'SmallIntegerField': int,
'PositiveBigIntegerField': int,
'PositiveSmallIntegerField': int,
'PositiveIntegerField': int,
'FloatField': Database.NATIVE_FLOAT,
'DateTimeField': Database.TIMESTAMP,
'DateField': Database.Date,
'DecimalField': Database.NUMBER,
}
def __init__(self, field):
internal_type = getattr(field, "target_field", field).get_internal_type()
internal_type = getattr(field, 'target_field', field).get_internal_type()
self.db_type = self.types.get(internal_type, str)
self.bound_param = None
@@ -44,54 +43,49 @@ class Oracle_datetime(datetime.datetime):
A datetime object, with an additional class attribute
to tell cx_Oracle to save the microseconds too.
"""
input_size = Database.TIMESTAMP
@classmethod
def from_datetime(cls, dt):
return Oracle_datetime(
dt.year,
dt.month,
dt.day,
dt.hour,
dt.minute,
dt.second,
dt.microsecond,
dt.year, dt.month, dt.day,
dt.hour, dt.minute, dt.second, dt.microsecond,
)
class BulkInsertMapper:
BLOB = "TO_BLOB(%s)"
CLOB = "TO_CLOB(%s)"
DATE = "TO_DATE(%s)"
INTERVAL = "CAST(%s as INTERVAL DAY(9) TO SECOND(6))"
NUMBER = "TO_NUMBER(%s)"
TIMESTAMP = "TO_TIMESTAMP(%s)"
BLOB = 'TO_BLOB(%s)'
CLOB = 'TO_CLOB(%s)'
DATE = 'TO_DATE(%s)'
INTERVAL = 'CAST(%s as INTERVAL DAY(9) TO SECOND(6))'
NUMBER = 'TO_NUMBER(%s)'
TIMESTAMP = 'TO_TIMESTAMP(%s)'
types = {
"AutoField": NUMBER,
"BigAutoField": NUMBER,
"BigIntegerField": NUMBER,
"BinaryField": BLOB,
"BooleanField": NUMBER,
"DateField": DATE,
"DateTimeField": TIMESTAMP,
"DecimalField": NUMBER,
"DurationField": INTERVAL,
"FloatField": NUMBER,
"IntegerField": NUMBER,
"PositiveBigIntegerField": NUMBER,
"PositiveIntegerField": NUMBER,
"PositiveSmallIntegerField": NUMBER,
"SmallAutoField": NUMBER,
"SmallIntegerField": NUMBER,
"TextField": CLOB,
"TimeField": TIMESTAMP,
'AutoField': NUMBER,
'BigAutoField': NUMBER,
'BigIntegerField': NUMBER,
'BinaryField': BLOB,
'BooleanField': NUMBER,
'DateField': DATE,
'DateTimeField': TIMESTAMP,
'DecimalField': NUMBER,
'DurationField': INTERVAL,
'FloatField': NUMBER,
'IntegerField': NUMBER,
'NullBooleanField': NUMBER,
'PositiveBigIntegerField': NUMBER,
'PositiveIntegerField': NUMBER,
'PositiveSmallIntegerField': NUMBER,
'SmallAutoField': NUMBER,
'SmallIntegerField': NUMBER,
'TextField': CLOB,
'TimeField': TIMESTAMP,
}
def dsn(settings_dict):
if settings_dict["PORT"]:
host = settings_dict["HOST"].strip() or "localhost"
return Database.makedsn(host, int(settings_dict["PORT"]), settings_dict["NAME"])
return settings_dict["NAME"]
if settings_dict['PORT']:
host = settings_dict['HOST'].strip() or 'localhost'
return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME'])
return settings_dict['NAME']
@@ -9,14 +9,14 @@ class DatabaseValidation(BaseDatabaseValidation):
if field.db_index and field_type.lower() in self.connection._limited_data_types:
errors.append(
checks.Warning(
"Oracle does not support a database index on %s columns."
'Oracle does not support a database index on %s columns.'
% field_type,
hint=(
"An index won't be created. Silence this warning if "
"you don't care about it."
),
obj=field,
id="fields.W162",
id='fields.W162',
)
)
return errors
@@ -11,10 +11,11 @@ from contextlib import contextmanager
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError
from django.db import connections
from django.db import DatabaseError as WrappedDatabaseError, connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.db.backends.utils import (
CursorDebugWrapper as BaseCursorDebugWrapper,
)
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
@@ -29,17 +30,14 @@ except ImportError as e:
def psycopg2_version():
version = psycopg2.__version__.split(" ", 1)[0]
version = psycopg2.__version__.split(' ', 1)[0]
return get_version_tuple(version)
PSYCOPG2_VERSION = psycopg2_version()
if PSYCOPG2_VERSION < (2, 8, 4):
raise ImproperlyConfigured(
"psycopg2 version 2.8.4 or newer is required; you have %s"
% psycopg2.__version__
)
if PSYCOPG2_VERSION < (2, 5, 4):
raise ImproperlyConfigured("psycopg2_version 2.5.4 or newer is required; you have %s" % psycopg2.__version__)
# Some of these import psycopg2, so import them after checking if it's installed.
@@ -58,68 +56,69 @@ psycopg2.extras.register_uuid()
INETARRAY_OID = 1041
INETARRAY = psycopg2.extensions.new_array_type(
(INETARRAY_OID,),
"INETARRAY",
'INETARRAY',
psycopg2.extensions.UNICODE,
)
psycopg2.extensions.register_type(INETARRAY)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "postgresql"
display_name = "PostgreSQL"
vendor = 'postgresql'
display_name = 'PostgreSQL'
# This dictionary maps Field objects to their associated PostgreSQL column
# types, as strings. Column-type strings can contain format strings; they'll
# be interpolated against the values of Field.__dict__ before being output.
# If a column type is set to None, it won't be included in the output.
data_types = {
"AutoField": "serial",
"BigAutoField": "bigserial",
"BinaryField": "bytea",
"BooleanField": "boolean",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "timestamp with time zone",
"DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
"DurationField": "interval",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "double precision",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "inet",
"GenericIPAddressField": "inet",
"JSONField": "jsonb",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint",
"PositiveIntegerField": "integer",
"PositiveSmallIntegerField": "smallint",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "smallserial",
"SmallIntegerField": "smallint",
"TextField": "text",
"TimeField": "time",
"UUIDField": "uuid",
'AutoField': 'serial',
'BigAutoField': 'bigserial',
'BinaryField': 'bytea',
'BooleanField': 'boolean',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'timestamp with time zone',
'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
'DurationField': 'interval',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'double precision',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'inet',
'GenericIPAddressField': 'inet',
'JSONField': 'jsonb',
'NullBooleanField': 'boolean',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint',
'PositiveIntegerField': 'integer',
'PositiveSmallIntegerField': 'smallint',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'smallserial',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'UUIDField': 'uuid',
}
data_type_check_constraints = {
"PositiveBigIntegerField": '"%(column)s" >= 0',
"PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0',
'PositiveBigIntegerField': '"%(column)s" >= 0',
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}
operators = {
"exact": "= %s",
"iexact": "= UPPER(%s)",
"contains": "LIKE %s",
"icontains": "LIKE UPPER(%s)",
"regex": "~ %s",
"iregex": "~* %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE %s",
"endswith": "LIKE %s",
"istartswith": "LIKE UPPER(%s)",
"iendswith": "LIKE UPPER(%s)",
'exact': '= %s',
'iexact': '= UPPER(%s)',
'contains': 'LIKE %s',
'icontains': 'LIKE UPPER(%s)',
'regex': '~ %s',
'iregex': '~* %s',
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': 'LIKE %s',
'endswith': 'LIKE %s',
'istartswith': 'LIKE UPPER(%s)',
'iendswith': 'LIKE UPPER(%s)',
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -130,16 +129,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
#
# Note: we use str.format() here for readability as '%' is used as a wildcard for
# the LIKE operator.
pattern_esc = (
r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
)
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
pattern_ops = {
"contains": "LIKE '%%' || {} || '%%'",
"icontains": "LIKE '%%' || UPPER({}) || '%%'",
"startswith": "LIKE {} || '%%'",
"istartswith": "LIKE UPPER({}) || '%%'",
"endswith": "LIKE '%%' || {}",
"iendswith": "LIKE '%%' || UPPER({})",
'contains': "LIKE '%%' || {} || '%%'",
'icontains': "LIKE '%%' || UPPER({}) || '%%'",
'startswith': "LIKE {} || '%%'",
'istartswith': "LIKE UPPER({}) || '%%'",
'endswith': "LIKE '%%' || {}",
'iendswith': "LIKE '%%' || UPPER({})",
}
Database = Database
@@ -156,46 +153,33 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
settings_dict = self.settings_dict
# None may be used to connect to the default 'postgres' db
if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
"service"
):
if settings_dict['NAME'] == '':
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME or OPTIONS['service'] value."
)
if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
"Please supply the NAME value.")
if len(settings_dict['NAME'] or '') > self.ops.max_name_length():
raise ImproperlyConfigured(
"The database name '%s' (%d characters) is longer than "
"PostgreSQL's limit of %d characters. Supply a shorter NAME "
"in settings.DATABASES."
% (
settings_dict["NAME"],
len(settings_dict["NAME"]),
"in settings.DATABASES." % (
settings_dict['NAME'],
len(settings_dict['NAME']),
self.ops.max_name_length(),
)
)
conn_params = {}
if settings_dict["NAME"]:
conn_params = {
"database": settings_dict["NAME"],
**settings_dict["OPTIONS"],
}
elif settings_dict["NAME"] is None:
# Connect to the default 'postgres' db.
settings_dict.get("OPTIONS", {}).pop("service", None)
conn_params = {"database": "postgres", **settings_dict["OPTIONS"]}
else:
conn_params = {**settings_dict["OPTIONS"]}
conn_params.pop("isolation_level", None)
if settings_dict["USER"]:
conn_params["user"] = settings_dict["USER"]
if settings_dict["PASSWORD"]:
conn_params["password"] = settings_dict["PASSWORD"]
if settings_dict["HOST"]:
conn_params["host"] = settings_dict["HOST"]
if settings_dict["PORT"]:
conn_params["port"] = settings_dict["PORT"]
conn_params = {
'database': settings_dict['NAME'] or 'postgres',
**settings_dict['OPTIONS'],
}
conn_params.pop('isolation_level', None)
if settings_dict['USER']:
conn_params['user'] = settings_dict['USER']
if settings_dict['PASSWORD']:
conn_params['password'] = settings_dict['PASSWORD']
if settings_dict['HOST']:
conn_params['host'] = settings_dict['HOST']
if settings_dict['PORT']:
conn_params['port'] = settings_dict['PORT']
return conn_params
@async_unsafe
@@ -207,9 +191,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# default when no value is explicitly specified in options.
# - before calling _set_autocommit() because if autocommit is on, that
# will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
options = self.settings_dict["OPTIONS"]
options = self.settings_dict['OPTIONS']
try:
self.isolation_level = options["isolation_level"]
self.isolation_level = options['isolation_level']
except KeyError:
self.isolation_level = connection.isolation_level
else:
@@ -219,15 +203,13 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# Register dummy loads() to avoid a round trip from psycopg2's decode
# to json.dumps() to json.loads(), when using a custom decoder in
# JSONField.
psycopg2.extras.register_default_jsonb(
conn_or_curs=connection, loads=lambda x: x
)
psycopg2.extras.register_default_jsonb(conn_or_curs=connection, loads=lambda x: x)
return connection
def ensure_timezone(self):
if self.connection is None:
return False
conn_timezone_name = self.connection.get_parameter_status("TimeZone")
conn_timezone_name = self.connection.get_parameter_status('TimeZone')
timezone_name = self.timezone_name
if timezone_name and conn_timezone_name != timezone_name:
with self.connection.cursor() as cursor:
@@ -236,7 +218,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
return False
def init_connection_state(self):
self.connection.set_client_encoding("UTF8")
self.connection.set_client_encoding('UTF8')
timezone_changed = self.ensure_timezone()
if timezone_changed:
@@ -249,9 +231,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if name:
# In autocommit mode, the cursor will be used outside of a
# transaction, hence use a holdable cursor.
cursor = self.connection.cursor(
name, scrollable=False, withhold=self.connection.autocommit
)
cursor = self.connection.cursor(name, scrollable=False, withhold=self.connection.autocommit)
else:
cursor = self.connection.cursor()
cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
@@ -269,18 +249,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# For now, it's here so that every use of "threading" is
# also async-compatible.
try:
current_task = asyncio.current_task()
if hasattr(asyncio, 'current_task'):
# Python 3.7 and up
current_task = asyncio.current_task()
else:
# Python 3.6
current_task = asyncio.Task.current_task()
except RuntimeError:
current_task = None
# Current task can be none even if the current_task call didn't error
if current_task:
task_ident = str(id(current_task))
else:
task_ident = "sync"
task_ident = 'sync'
# Use that and the thread ident to get a unique name
return self._cursor(
name="_django_curs_%d_%s_%d"
% (
name='_django_curs_%d_%s_%d' % (
# Avoid reusing name in other threads / tasks
threading.current_thread().ident,
task_ident,
@@ -298,14 +282,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
afterward.
"""
with self.cursor() as cursor:
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
cursor.execute("SET CONSTRAINTS ALL DEFERRED")
cursor.execute('SET CONSTRAINTS ALL IMMEDIATE')
cursor.execute('SET CONSTRAINTS ALL DEFERRED')
def is_usable(self):
try:
# Use a psycopg cursor directly, bypassing Django's utilities.
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.execute('SELECT 1')
except Database.Error:
return False
else:
@@ -313,31 +297,22 @@ class DatabaseWrapper(BaseDatabaseWrapper):
@contextmanager
def _nodb_cursor(self):
cursor = None
try:
with super()._nodb_cursor() as cursor:
yield cursor
except (Database.DatabaseError, WrappedDatabaseError):
if cursor is not None:
raise
warnings.warn(
"Normally Django will use a connection to the 'postgres' database "
"to avoid running initialization queries against the production "
"database when it's not needed (for example, when running tests). "
"Django was unable to create a connection to the 'postgres' database "
"and will use the first PostgreSQL database instead.",
RuntimeWarning,
RuntimeWarning
)
for connection in connections.all():
if (
connection.vendor == "postgresql"
and connection.settings_dict["NAME"] != "postgres"
):
if connection.vendor == 'postgresql' and connection.settings_dict['NAME'] != 'postgres':
conn = self.__class__(
{
**self.settings_dict,
"NAME": connection.settings_dict["NAME"],
},
{**self.settings_dict, 'NAME': connection.settings_dict['NAME']},
alias=self.alias,
)
try:
@@ -364,5 +339,5 @@ class CursorDebugWrapper(BaseCursorDebugWrapper):
return self.cursor.copy_expert(sql, file, *args)
def copy_to(self, file, table, *args, **kwargs):
with self.debug_sql(sql="COPY %s TO STDOUT" % table):
with self.debug_sql(sql='COPY %s TO STDOUT' % table):
return self.cursor.copy_to(file, table, *args, **kwargs)
@@ -4,53 +4,43 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "psql"
executable_name = 'psql'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
options = settings_dict.get("OPTIONS", {})
options = settings_dict.get('OPTIONS', {})
host = settings_dict.get("HOST")
port = settings_dict.get("PORT")
dbname = settings_dict.get("NAME")
user = settings_dict.get("USER")
passwd = settings_dict.get("PASSWORD")
passfile = options.get("passfile")
service = options.get("service")
sslmode = options.get("sslmode")
sslrootcert = options.get("sslrootcert")
sslcert = options.get("sslcert")
sslkey = options.get("sslkey")
host = settings_dict.get('HOST')
port = settings_dict.get('PORT')
dbname = settings_dict.get('NAME') or 'postgres'
user = settings_dict.get('USER')
passwd = settings_dict.get('PASSWORD')
sslmode = options.get('sslmode')
sslrootcert = options.get('sslrootcert')
sslcert = options.get('sslcert')
sslkey = options.get('sslkey')
if not dbname and not service:
# Connect to the default 'postgres' db.
dbname = "postgres"
if user:
args += ["-U", user]
args += ['-U', user]
if host:
args += ["-h", host]
args += ['-h', host]
if port:
args += ["-p", str(port)]
if dbname:
args += [dbname]
args += ['-p', str(port)]
args += [dbname]
args.extend(parameters)
env = {}
if passwd:
env["PGPASSWORD"] = str(passwd)
if service:
env["PGSERVICE"] = str(service)
env['PGPASSWORD'] = str(passwd)
if sslmode:
env["PGSSLMODE"] = str(sslmode)
env['PGSSLMODE'] = str(sslmode)
if sslrootcert:
env["PGSSLROOTCERT"] = str(sslrootcert)
env['PGSSLROOTCERT'] = str(sslrootcert)
if sslcert:
env["PGSSLCERT"] = str(sslcert)
env['PGSSLCERT'] = str(sslcert)
if sslkey:
env["PGSSLKEY"] = str(sslkey)
if passfile:
env["PGPASSFILE"] = str(passfile)
env['PGSSLKEY'] = str(sslkey)
return args, (env or None)
def runshell(self, parameters):
@@ -2,12 +2,12 @@ import sys
from psycopg2 import errorcodes
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.utils import strip_quotes
class DatabaseCreation(BaseDatabaseCreation):
def _quote_name(self, name):
return self.connection.ops.quote_name(name)
@@ -20,35 +20,30 @@ class DatabaseCreation(BaseDatabaseCreation):
return suffix and "WITH" + suffix
def sql_table_creation_suffix(self):
test_settings = self.connection.settings_dict["TEST"]
if test_settings.get("COLLATION") is not None:
raise ImproperlyConfigured(
"PostgreSQL does not support collation setting at database "
"creation time."
)
test_settings = self.connection.settings_dict['TEST']
assert test_settings['COLLATION'] is None, (
"PostgreSQL does not support collation setting at database creation time."
)
return self._get_database_create_suffix(
encoding=test_settings["CHARSET"],
template=test_settings.get("TEMPLATE"),
encoding=test_settings['CHARSET'],
template=test_settings.get('TEMPLATE'),
)
def _database_exists(self, cursor, database_name):
cursor.execute(
"SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
[strip_quotes(database_name)],
)
cursor.execute('SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s', [strip_quotes(database_name)])
return cursor.fetchone() is not None
def _execute_create_test_db(self, cursor, parameters, keepdb=False):
try:
if keepdb and self._database_exists(cursor, parameters["dbname"]):
if keepdb and self._database_exists(cursor, parameters['dbname']):
# If the database should be kept and it already exists, don't
# try to create a new one.
return
super()._execute_create_test_db(cursor, parameters, keepdb)
except Exception as e:
if getattr(e.__cause__, "pgcode", "") != errorcodes.DUPLICATE_DATABASE:
if getattr(e.__cause__, 'pgcode', '') != errorcodes.DUPLICATE_DATABASE:
# All errors except "database already exists" cancel tests.
self.log("Got an error creating the test database: %s" % e)
self.log('Got an error creating the test database: %s' % e)
sys.exit(2)
elif not keepdb:
# If the database should be kept, ignore "database already
@@ -60,11 +55,11 @@ class DatabaseCreation(BaseDatabaseCreation):
# to the template database.
self.connection.close()
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
test_db_params = {
"dbname": self._quote_name(target_database_name),
"suffix": self._get_database_create_suffix(template=source_database_name),
'dbname': self._quote_name(target_database_name),
'suffix': self._get_database_create_suffix(template=source_database_name),
}
with self._nodb_cursor() as cursor:
try:
@@ -72,16 +67,11 @@ class DatabaseCreation(BaseDatabaseCreation):
except Exception:
try:
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
cursor.execute('DROP DATABASE %(dbname)s' % test_db_params)
self._execute_create_test_db(cursor, test_db_params, keepdb)
except Exception as e:
self.log("Got an error cloning the test database: %s" % e)
self.log('Got an error cloning the test database: %s' % e)
sys.exit(2)
@@ -53,32 +53,41 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_over_clause = True
only_supports_unbounded_with_preceding_and_following = True
supports_aggregate_filter_clause = True
supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
supported_explain_formats = {'JSON', 'TEXT', 'XML', 'YAML'}
validates_explain_options = False # A query will error on invalid options.
supports_deferrable_unique_constraints = True
has_json_operators = True
json_key_contains_list_matching_requires_list = True
test_collations = {
"non_default": "sv-x-icu",
"swedish_ci": "sv-x-icu",
}
test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
django_test_skips = {
"opclasses are PostgreSQL only.": {
"indexes.tests.SchemaIndexesNotPostgreSQLTests."
"test_create_index_ignores_opclasses",
'opclasses are PostgreSQL only.': {
'indexes.tests.SchemaIndexesNotPostgreSQLTests.test_create_index_ignores_opclasses',
},
}
@cached_property
def test_collations(self):
# PostgreSQL < 10 doesn't support ICU collations.
if self.is_postgresql_10:
return {
'non_default': 'sv-x-icu',
'swedish_ci': 'sv-x-icu',
}
return {}
@cached_property
def introspected_field_types(self):
return {
**super().introspected_field_types,
"PositiveBigIntegerField": "BigIntegerField",
"PositiveIntegerField": "IntegerField",
"PositiveSmallIntegerField": "SmallIntegerField",
'PositiveBigIntegerField': 'BigIntegerField',
'PositiveIntegerField': 'IntegerField',
'PositiveSmallIntegerField': 'SmallIntegerField',
}
@cached_property
def is_postgresql_10(self):
return self.connection.pg_version >= 100000
@cached_property
def is_postgresql_11(self):
return self.connection.pg_version >= 110000
@@ -91,9 +100,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
def is_postgresql_13(self):
return self.connection.pg_version >= 130000
has_websearch_to_tsquery = property(operator.attrgetter("is_postgresql_11"))
supports_covering_indexes = property(operator.attrgetter("is_postgresql_11"))
supports_covering_gist_indexes = property(operator.attrgetter("is_postgresql_12"))
supports_non_deterministic_collations = property(
operator.attrgetter("is_postgresql_12")
)
has_brin_autosummarize = property(operator.attrgetter('is_postgresql_10'))
has_websearch_to_tsquery = property(operator.attrgetter('is_postgresql_11'))
supports_table_partitions = property(operator.attrgetter('is_postgresql_10'))
supports_covering_indexes = property(operator.attrgetter('is_postgresql_11'))
supports_covering_gist_indexes = property(operator.attrgetter('is_postgresql_12'))
supports_non_deterministic_collations = property(operator.attrgetter('is_postgresql_12'))
supports_alternate_collation_providers = property(operator.attrgetter('is_postgresql_10'))
@@ -1,7 +1,5 @@
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection,
FieldInfo,
TableInfo,
BaseDatabaseIntrospection, FieldInfo, TableInfo,
)
from django.db.models import Index
@@ -9,66 +7,55 @@ from django.db.models import Index
class DatabaseIntrospection(BaseDatabaseIntrospection):
# Maps type codes to Django Field types.
data_types_reverse = {
16: "BooleanField",
17: "BinaryField",
20: "BigIntegerField",
21: "SmallIntegerField",
23: "IntegerField",
25: "TextField",
700: "FloatField",
701: "FloatField",
869: "GenericIPAddressField",
1042: "CharField", # blank-padded
1043: "CharField",
1082: "DateField",
1083: "TimeField",
1114: "DateTimeField",
1184: "DateTimeField",
1186: "DurationField",
1266: "TimeField",
1700: "DecimalField",
2950: "UUIDField",
3802: "JSONField",
16: 'BooleanField',
17: 'BinaryField',
20: 'BigIntegerField',
21: 'SmallIntegerField',
23: 'IntegerField',
25: 'TextField',
700: 'FloatField',
701: 'FloatField',
869: 'GenericIPAddressField',
1042: 'CharField', # blank-padded
1043: 'CharField',
1082: 'DateField',
1083: 'TimeField',
1114: 'DateTimeField',
1184: 'DateTimeField',
1186: 'DurationField',
1266: 'TimeField',
1700: 'DecimalField',
2950: 'UUIDField',
3802: 'JSONField',
}
# A hook for subclasses.
index_default_access_method = "btree"
index_default_access_method = 'btree'
ignored_tables = []
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.default and "nextval" in description.default:
if field_type == "IntegerField":
return "AutoField"
elif field_type == "BigIntegerField":
return "BigAutoField"
elif field_type == "SmallIntegerField":
return "SmallAutoField"
if description.default and 'nextval' in description.default:
if field_type == 'IntegerField':
return 'AutoField'
elif field_type == 'BigIntegerField':
return 'BigAutoField'
elif field_type == 'SmallIntegerField':
return 'SmallAutoField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
cursor.execute(
"""
SELECT
c.relname,
CASE
WHEN c.relispartition THEN 'p'
WHEN c.relkind IN ('m', 'v') THEN 'v'
ELSE 't'
END
cursor.execute("""
SELECT c.relname,
CASE WHEN {} THEN 'p' WHEN c.relkind IN ('m', 'v') THEN 'v' ELSE 't' END
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
"""
)
return [
TableInfo(*row)
for row in cursor.fetchall()
if row[0] not in self.ignored_tables
]
""".format('c.relispartition' if self.connection.features.supports_table_partitions else 'FALSE'))
return [TableInfo(*row) for row in cursor.fetchall() if row[0] not in self.ignored_tables]
def get_table_description(self, cursor, table_name):
"""
@@ -78,8 +65,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Query the pg_catalog tables as cursor.description does not reliably
# return the nullable property and information_schema.columns does not
# contain details of materialized views.
cursor.execute(
"""
cursor.execute("""
SELECT
a.attname AS column_name,
NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
@@ -95,13 +81,9 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND c.relname = %s
AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid)
""",
[table_name],
)
""", [table_name])
field_map = {line[0]: line[1:] for line in cursor.fetchall()}
cursor.execute(
"SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
)
cursor.execute("SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name))
return [
FieldInfo(
line.name,
@@ -116,30 +98,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
]
def get_sequences(self, cursor, table_name, table_fields=()):
cursor.execute(
"""
cursor.execute("""
SELECT s.relname as sequence_name, col.attname
FROM pg_class s
JOIN pg_namespace sn ON sn.oid = s.relnamespace
JOIN
pg_depend d ON d.refobjid = s.oid
AND d.refclassid = 'pg_class'::regclass
JOIN
pg_attrdef ad ON ad.oid = d.objid
AND d.classid = 'pg_attrdef'::regclass
JOIN
pg_attribute col ON col.attrelid = ad.adrelid
AND col.attnum = ad.adnum
JOIN pg_depend d ON d.refobjid = s.oid AND d.refclassid = 'pg_class'::regclass
JOIN pg_attrdef ad ON ad.oid = d.objid AND d.classid = 'pg_attrdef'::regclass
JOIN pg_attribute col ON col.attrelid = ad.adrelid AND col.attnum = ad.adnum
JOIN pg_class tbl ON tbl.oid = ad.adrelid
WHERE s.relkind = 'S'
AND d.deptype in ('a', 'n')
AND pg_catalog.pg_table_is_visible(tbl.oid)
AND tbl.relname = %s
""",
[table_name],
)
""", [table_name])
return [
{"name": row[0], "table": table_name, "column": row[1]}
{'name': row[0], 'table': table_name, 'column': row[1]}
for row in cursor.fetchall()
]
@@ -148,29 +121,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
Return a dictionary of {field_name: (field_name_other_table, other_table)}
representing all relationships to the given table.
"""
return {
row[0]: (row[2], row[1]) for row in self.get_key_columns(cursor, table_name)
}
return {row[0]: (row[2], row[1]) for row in self.get_key_columns(cursor, table_name)}
def get_key_columns(self, cursor, table_name):
cursor.execute(
"""
cursor.execute("""
SELECT a1.attname, c2.relname, a2.attname
FROM pg_constraint con
LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
LEFT JOIN
pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN
pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
WHERE
c1.relname = %s AND
con.contype = 'f' AND
c1.relnamespace = c2.relnamespace AND
pg_catalog.pg_table_is_visible(c1.oid)
""",
[table_name],
)
""", [table_name])
return cursor.fetchall()
def get_constraints(self, cursor, table_name):
@@ -183,8 +149,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Loop over the key table, collecting things as constraints. The column
# array must return column names in the same order in which they were
# created.
cursor.execute(
"""
cursor.execute("""
SELECT
c.conname,
array(
@@ -203,9 +168,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
FROM pg_constraint AS c
JOIN pg_class AS cl ON c.conrelid = cl.oid
WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
""",
[table_name],
)
""", [table_name])
for constraint, columns, kind, used_cols, options in cursor.fetchall():
constraints[constraint] = {
"columns": columns,
@@ -218,17 +181,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"options": options,
}
# Now get indexes
cursor.execute(
"""
cursor.execute("""
SELECT
indexname,
array_agg(attname ORDER BY arridx),
indisunique,
indisprimary,
array_agg(ordering ORDER BY arridx),
amname,
exprdef,
s2.attoptions
indexname, array_agg(attname ORDER BY arridx), indisunique, indisprimary,
array_agg(ordering ORDER BY arridx), amname, exprdef, s2.attoptions
FROM (
SELECT
c2.relname as indexname, idx.*, attr.attname, am.amname,
@@ -245,40 +201,23 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
c2.reloptions as attoptions
FROM (
SELECT *
FROM
pg_index i,
unnest(i.indkey, i.indoption)
WITH ORDINALITY koi(key, option, arridx)
FROM pg_index i, unnest(i.indkey, i.indoption) WITH ORDINALITY koi(key, option, arridx)
) idx
LEFT JOIN pg_class c ON idx.indrelid = c.oid
LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
LEFT JOIN pg_am am ON c2.relam = am.oid
LEFT JOIN
pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
LEFT JOIN pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
) s2
GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
""",
[self.index_default_access_method, table_name],
)
for (
index,
columns,
unique,
primary,
orders,
type_,
definition,
options,
) in cursor.fetchall():
""", [self.index_default_access_method, table_name])
for index, columns, unique, primary, orders, type_, definition, options in cursor.fetchall():
if index not in constraints:
basic_index = (
type_ == self.index_default_access_method
and
type_ == self.index_default_access_method and
# '_btree' references
# django.contrib.postgres.indexes.BTreeIndex.suffix.
not index.endswith("_btree")
and options is None
not index.endswith('_btree') and options is None
)
constraints[index] = {
"columns": columns if columns != [None] else [],
@@ -2,38 +2,20 @@ from psycopg2.extras import Inet
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "varchar"
explain_prefix = "EXPLAIN"
explain_options = frozenset(
[
"ANALYZE",
"BUFFERS",
"COSTS",
"SETTINGS",
"SUMMARY",
"TIMING",
"VERBOSE",
"WAL",
]
)
cast_char_field_without_max_length = 'varchar'
explain_prefix = 'EXPLAIN'
cast_data_types = {
"AutoField": "integer",
"BigAutoField": "bigint",
"SmallAutoField": "smallint",
'AutoField': 'integer',
'BigAutoField': 'bigint',
'SmallAutoField': 'smallint',
}
def unification_cast_sql(self, output_field):
internal_type = output_field.get_internal_type()
if internal_type in (
"GenericIPAddressField",
"IPAddressField",
"TimeField",
"UUIDField",
):
if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
# PostgreSQL will resolve a union as type 'text' if input types are
# 'unknown'.
# https://www.postgresql.org/docs/current/typeconv-union-case.html
@@ -41,19 +23,17 @@ class DatabaseOperations(BaseDatabaseOperations):
# PostgreSQL configuration so we need to explicitly cast them.
# We must also remove components of the type within brackets:
# varchar(255) -> varchar.
return (
"CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
)
return "%s"
return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
return '%s'
def date_extract_sql(self, lookup_type, field_name):
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
if lookup_type == "week_day":
if lookup_type == 'week_day':
# For consistency across backends, we return Sunday=1, Saturday=7.
return "EXTRACT('dow' FROM %s) + 1" % field_name
elif lookup_type == "iso_week_day":
elif lookup_type == 'iso_week_day':
return "EXTRACT('isodow' FROM %s)" % field_name
elif lookup_type == "iso_year":
elif lookup_type == 'iso_year':
return "EXTRACT('isoyear' FROM %s)" % field_name
else:
return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)
@@ -64,27 +44,24 @@ class DatabaseOperations(BaseDatabaseOperations):
return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)
def _prepare_tzname_delta(self, tzname):
tzname, sign, offset = split_tzname_delta(tzname)
if offset:
sign = "-" if sign == "+" else "+"
return f"{tzname}{sign}{offset}"
if '+' in tzname:
return tzname.replace('+', '-')
elif '-' in tzname:
return tzname.replace('-', '+')
return tzname
def _convert_field_to_tz(self, field_name, tzname):
if tzname and settings.USE_TZ:
field_name = "%s AT TIME ZONE '%s'" % (
field_name,
self._prepare_tzname_delta(tzname),
)
field_name = "%s AT TIME ZONE '%s'" % (field_name, self._prepare_tzname_delta(tzname))
return field_name
def datetime_cast_date_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "(%s)::date" % field_name
return '(%s)::date' % field_name
def datetime_cast_time_sql(self, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
return "(%s)::time" % field_name
return '(%s)::time' % field_name
def datetime_extract_sql(self, lookup_type, field_name, tzname):
field_name = self._convert_field_to_tz(field_name, tzname)
@@ -110,30 +87,21 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
lookup = "%s"
lookup = '%s'
# Cast text lookups to text to allow things like filter(x__contains=4)
if lookup_type in (
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"regex",
"iregex",
):
if internal_type in ("IPAddressField", "GenericIPAddressField"):
if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
if internal_type in ('IPAddressField', 'GenericIPAddressField'):
lookup = "HOST(%s)"
elif internal_type in ("CICharField", "CIEmailField", "CITextField"):
lookup = "%s::citext"
elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
lookup = '%s::citext'
else:
lookup = "%s::text"
# Use UPPER(x) for case-insensitive lookups; it's faster.
if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
lookup = "UPPER(%s)" % lookup
if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
lookup = 'UPPER(%s)' % lookup
return lookup
@@ -158,32 +126,29 @@ class DatabaseOperations(BaseDatabaseOperations):
# Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
# to truncate tables referenced by a foreign key in any other table.
sql_parts = [
style.SQL_KEYWORD("TRUNCATE"),
", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
style.SQL_KEYWORD('TRUNCATE'),
', '.join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
]
if reset_sequences:
sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY"))
sql_parts.append(style.SQL_KEYWORD('RESTART IDENTITY'))
if allow_cascade:
sql_parts.append(style.SQL_KEYWORD("CASCADE"))
return ["%s;" % " ".join(sql_parts)]
sql_parts.append(style.SQL_KEYWORD('CASCADE'))
return ['%s;' % ' '.join(sql_parts)]
def sequence_reset_by_name_sql(self, style, sequences):
# 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
# to reset sequence indices
sql = []
for sequence_info in sequences:
table_name = sequence_info["table"]
table_name = sequence_info['table']
# 'id' will be the case if it's an m2m using an autogenerated
# intermediate table (see BaseDatabaseIntrospection.sequence_list).
column_name = sequence_info["column"] or "id"
sql.append(
"%s setval(pg_get_serial_sequence('%s','%s'), 1, false);"
% (
style.SQL_KEYWORD("SELECT"),
style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name),
)
)
column_name = sequence_info['column'] or 'id'
sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(self.quote_name(table_name)),
style.SQL_FIELD(column_name),
))
return sql
def tablespace_sql(self, tablespace, inline=False):
@@ -194,36 +159,31 @@ class DatabaseOperations(BaseDatabaseOperations):
def sequence_reset_sql(self, style, model_list):
from django.db import models
output = []
qn = self.quote_name
for model in model_list:
# Use `coalesce` to set the sequence for each model to the max pk
# value if there are records, or 1 if there are none. Set the
# `is_called` property (the third argument to `setval`) to true if
# there are records (as the max pk value is already in use),
# otherwise set it to false. Use pg_get_serial_sequence to get the
# underlying sequence name from the table name and column name.
# Use `coalesce` to set the sequence for each model to the max pk value if there are records,
# or 1 if there are none. Set the `is_called` property (the third argument to `setval`) to true
# if there are records (as the max pk value is already in use), otherwise set it to false.
# Use pg_get_serial_sequence to get the underlying sequence name from the table name
# and column name (available since PostgreSQL 8)
for f in model._meta.local_fields:
if isinstance(f, models.AutoField):
output.append(
"%s setval(pg_get_serial_sequence('%s','%s'), "
"coalesce(max(%s), 1), max(%s) %s null) %s %s;"
% (
style.SQL_KEYWORD("SELECT"),
"coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
style.SQL_KEYWORD('SELECT'),
style.SQL_TABLE(qn(model._meta.db_table)),
style.SQL_FIELD(f.column),
style.SQL_FIELD(qn(f.column)),
style.SQL_FIELD(qn(f.column)),
style.SQL_KEYWORD("IS NOT"),
style.SQL_KEYWORD("FROM"),
style.SQL_KEYWORD('IS NOT'),
style.SQL_KEYWORD('FROM'),
style.SQL_TABLE(qn(model._meta.db_table)),
)
)
# Only one AutoField is allowed per model, so don't bother
# continuing.
break
break # Only one AutoField is allowed per model, so don't bother continuing.
return output
def prep_for_iexact_query(self, x):
@@ -245,9 +205,9 @@ class DatabaseOperations(BaseDatabaseOperations):
def distinct_sql(self, fields, params):
if fields:
params = [param for param_list in params for param in param_list]
return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
else:
return ["DISTINCT"], []
return ['DISTINCT'], []
def last_executed_query(self, cursor, sql, params):
# https://www.psycopg.org/docs/cursor.html#cursor.query
@@ -258,16 +218,14 @@ class DatabaseOperations(BaseDatabaseOperations):
def return_insert_columns(self, fields):
if not fields:
return "", ()
return '', ()
columns = [
"%s.%s"
% (
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
) for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
return 'RETURNING %s' % ', '.join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
@@ -292,7 +250,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
def subtract_temporals(self, internal_type, lhs, rhs):
if internal_type == "DateField":
if internal_type == 'DateField':
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
@@ -300,27 +258,18 @@ class DatabaseOperations(BaseDatabaseOperations):
return super().subtract_temporals(internal_type, lhs, rhs)
def explain_query_prefix(self, format=None, **options):
prefix = super().explain_query_prefix(format)
extra = {}
# Normalize options.
if options:
options = {
name.upper(): "true" if value else "false"
for name, value in options.items()
}
for valid_option in self.explain_options:
value = options.pop(valid_option, None)
if value is not None:
extra[valid_option.upper()] = value
prefix = super().explain_query_prefix(format, **options)
if format:
extra["FORMAT"] = format
extra['FORMAT'] = format
if options:
extra.update({
name.upper(): 'true' if value else 'false'
for name, value in options.items()
})
if extra:
prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items())
prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
return prefix
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return (
"ON CONFLICT DO NOTHING"
if ignore_conflicts
else super().ignore_conflicts_suffix_sql(ignore_conflicts)
)
return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)
@@ -9,18 +9,16 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
sql_set_sequence_max = (
"SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
)
sql_set_sequence_owner = "ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s"
sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
sql_set_sequence_owner = 'ALTER SEQUENCE %(sequence)s OWNED BY %(table)s.%(column)s'
sql_create_index = (
"CREATE INDEX %(name)s ON %(table)s%(using)s "
"(%(columns)s)%(include)s%(extra)s%(condition)s"
'CREATE INDEX %(name)s ON %(table)s%(using)s '
'(%(columns)s)%(include)s%(extra)s%(condition)s'
)
sql_create_index_concurrently = (
"CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
"(%(columns)s)%(include)s%(extra)s%(condition)s"
'CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s '
'(%(columns)s)%(include)s%(extra)s%(condition)s'
)
sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"
@@ -28,23 +26,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Setting the constraint to IMMEDIATE to allow changing data in the same
# transaction.
sql_create_column_inline_fk = (
"CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
"; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
'CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s'
'; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE'
)
# Setting the constraint to IMMEDIATE runs any deferred checks to allow
# dropping it in the same transaction.
sql_delete_fk = (
"SET CONSTRAINTS %(name)s IMMEDIATE; "
"ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
)
sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"
sql_delete_fk = "SET CONSTRAINTS %(name)s IMMEDIATE; ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
sql_delete_procedure = 'DROP FUNCTION %(procedure)s(%(param_types)s)'
def quote_value(self, value):
if isinstance(value, str):
value = value.replace("%", "%%")
value = value.replace('%', '%%')
adapted = psycopg2.extensions.adapt(value)
if hasattr(adapted, "encoding"):
adapted.encoding = "utf8"
if hasattr(adapted, 'encoding'):
adapted.encoding = 'utf8'
# getquoted() returns a quoted bytestring of the adapted value.
return adapted.getquoted().decode()
@@ -65,7 +61,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _field_base_data_types(self, field):
# Yield base data types for array fields.
if field.base_field.get_internal_type() == "ArrayField":
if field.base_field.get_internal_type() == 'ArrayField':
yield from self._field_base_data_types(field.base_field)
else:
yield self._field_data_type(field.base_field)
@@ -84,52 +80,45 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
#
# The same doesn't apply to array fields such as varchar[size]
# and text[size], so skip them.
if "[" in db_type:
if '[' in db_type:
return None
if db_type.startswith("varchar"):
if db_type.startswith('varchar'):
return self._create_index_sql(
model,
fields=[field],
suffix="_like",
opclasses=["varchar_pattern_ops"],
suffix='_like',
opclasses=['varchar_pattern_ops'],
)
elif db_type.startswith("text"):
elif db_type.startswith('text'):
return self._create_index_sql(
model,
fields=[field],
suffix="_like",
opclasses=["text_pattern_ops"],
suffix='_like',
opclasses=['text_pattern_ops'],
)
return None
def _alter_column_type_sql(self, model, old_field, new_field, new_type):
self.sql_alter_column_type = "ALTER COLUMN %(column)s TYPE %(type)s"
self.sql_alter_column_type = 'ALTER COLUMN %(column)s TYPE %(type)s'
# Cast when data type changed.
using_sql = " USING %(column)s::%(type)s"
using_sql = ' USING %(column)s::%(type)s'
new_internal_type = new_field.get_internal_type()
old_internal_type = old_field.get_internal_type()
if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
if new_internal_type == 'ArrayField' and new_internal_type == old_internal_type:
# Compare base data types for array fields.
if list(self._field_base_data_types(old_field)) != list(
self._field_base_data_types(new_field)
):
if list(self._field_base_data_types(old_field)) != list(self._field_base_data_types(new_field)):
self.sql_alter_column_type += using_sql
elif self._field_data_type(old_field) != self._field_data_type(new_field):
self.sql_alter_column_type += using_sql
# Make ALTER TYPE with SERIAL make sense.
table = strip_quotes(model._meta.db_table)
serial_fields_map = {
"bigserial": "bigint",
"serial": "integer",
"smallserial": "smallint",
}
serial_fields_map = {'bigserial': 'bigint', 'serial': 'integer', 'smallserial': 'smallint'}
if new_type.lower() in serial_fields_map:
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
return (
(
self.sql_alter_column_type
% {
self.sql_alter_column_type % {
"column": self.quote_name(column),
"type": serial_fields_map[new_type.lower()],
},
@@ -137,35 +126,29 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
),
[
(
self.sql_delete_sequence
% {
self.sql_delete_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_create_sequence
% {
self.sql_create_sequence % {
"sequence": self.quote_name(sequence_name),
},
[],
),
(
self.sql_alter_column
% {
self.sql_alter_column % {
"table": self.quote_name(table),
"changes": self.sql_alter_column_default
% {
"changes": self.sql_alter_column_default % {
"column": self.quote_name(column),
"default": "nextval('%s')"
% self.quote_name(sequence_name),
},
"default": "nextval('%s')" % self.quote_name(sequence_name),
}
},
[],
),
(
self.sql_set_sequence_max
% {
self.sql_set_sequence_max % {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
@@ -173,31 +156,24 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
[],
),
(
self.sql_set_sequence_owner
% {
"table": self.quote_name(table),
"column": self.quote_name(column),
"sequence": self.quote_name(sequence_name),
self.sql_set_sequence_owner % {
'table': self.quote_name(table),
'column': self.quote_name(column),
'sequence': self.quote_name(sequence_name),
},
[],
),
],
)
elif (
old_field.db_parameters(connection=self.connection)["type"]
in serial_fields_map
):
elif old_field.db_parameters(connection=self.connection)['type'] in serial_fields_map:
# Drop the sequence if migrating away from AutoField.
column = strip_quotes(new_field.column)
sequence_name = "%s_%s_seq" % (table, column)
fragment, _ = super()._alter_column_type_sql(
model, old_field, new_field, new_type
)
sequence_name = '%s_%s_seq' % (table, column)
fragment, _ = super()._alter_column_type_sql(model, old_field, new_field, new_type)
return fragment, [
(
self.sql_delete_sequence
% {
"sequence": self.quote_name(sequence_name),
self.sql_delete_sequence % {
'sequence': self.quote_name(sequence_name),
},
[],
),
@@ -205,114 +181,58 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
else:
return super()._alter_column_type_sql(model, old_field, new_field, new_type)
def _alter_field(
self,
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict=False,
):
def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
# Drop indexes on varchar/text/citext columns that are changing to a
# different type.
if (old_field.db_index or old_field.unique) and (
(old_type.startswith("varchar") and not new_type.startswith("varchar"))
or (old_type.startswith("text") and not new_type.startswith("text"))
or (old_type.startswith("citext") and not new_type.startswith("citext"))
(old_type.startswith('varchar') and not new_type.startswith('varchar')) or
(old_type.startswith('text') and not new_type.startswith('text')) or
(old_type.startswith('citext') and not new_type.startswith('citext'))
):
index_name = self._create_index_name(
model._meta.db_table, [old_field.column], suffix="_like"
)
index_name = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_name))
super()._alter_field(
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict,
model, old_field, new_field, old_type, new_type, old_db_params,
new_db_params, strict,
)
# Added an index? Create any PostgreSQL-specific indexes.
if (not (old_field.db_index or old_field.unique) and new_field.db_index) or (
not old_field.unique and new_field.unique
):
if ((not (old_field.db_index or old_field.unique) and new_field.db_index) or
(not old_field.unique and new_field.unique)):
like_index_statement = self._create_like_index_sql(model, new_field)
if like_index_statement is not None:
self.execute(like_index_statement)
# Removed an index? Drop any PostgreSQL-specific indexes.
if old_field.unique and not (new_field.db_index or new_field.unique):
index_to_remove = self._create_index_name(
model._meta.db_table, [old_field.column], suffix="_like"
)
index_to_remove = self._create_index_name(model._meta.db_table, [old_field.column], suffix='_like')
self.execute(self._delete_index_sql(model, index_to_remove))
def _index_columns(self, table, columns, col_suffixes, opclasses):
if opclasses:
return IndexColumns(
table,
columns,
self.quote_name,
col_suffixes=col_suffixes,
opclasses=opclasses,
)
return IndexColumns(table, columns, self.quote_name, col_suffixes=col_suffixes, opclasses=opclasses)
return super()._index_columns(table, columns, col_suffixes, opclasses)
def add_index(self, model, index, concurrently=False):
self.execute(
index.create_sql(model, self, concurrently=concurrently), params=None
)
self.execute(index.create_sql(model, self, concurrently=concurrently), params=None)
def remove_index(self, model, index, concurrently=False):
self.execute(index.remove_sql(model, self, concurrently=concurrently))
def _delete_index_sql(self, model, name, sql=None, concurrently=False):
sql = (
self.sql_delete_index_concurrently
if concurrently
else self.sql_delete_index
)
sql = self.sql_delete_index_concurrently if concurrently else self.sql_delete_index
return super()._delete_index_sql(model, name, sql)
def _create_index_sql(
self,
model,
*,
fields=None,
name=None,
suffix="",
using="",
db_tablespace=None,
col_suffixes=(),
sql=None,
opclasses=(),
condition=None,
concurrently=False,
include=None,
expressions=None,
self, model, *, fields=None, name=None, suffix='', using='',
db_tablespace=None, col_suffixes=(), sql=None, opclasses=(),
condition=None, concurrently=False, include=None, expressions=None,
):
sql = (
self.sql_create_index
if not concurrently
else self.sql_create_index_concurrently
)
sql = self.sql_create_index if not concurrently else self.sql_create_index_concurrently
return super()._create_index_sql(
model,
fields=fields,
name=name,
suffix=suffix,
using=using,
db_tablespace=db_tablespace,
col_suffixes=col_suffixes,
sql=sql,
opclasses=opclasses,
condition=condition,
include=include,
model, fields=fields, name=name, suffix=suffix, using=using,
db_tablespace=db_tablespace, col_suffixes=col_suffixes, sql=sql,
opclasses=opclasses, condition=condition, include=include,
expressions=expressions,
)
@@ -14,15 +14,18 @@ import warnings
from itertools import chain
from sqlite3 import dbapi2 as Database
import pytz
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper, timezone_constructor
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.dateparse import parse_datetime, parse_time
from django.utils.duration import duration_microseconds
from django.utils.regex_helper import _lazy_re_compile
from django.utils.version import PY38
from .client import DatabaseClient
from .creation import DatabaseCreation
@@ -46,11 +49,9 @@ def none_guard(func):
are NULL. This decorator simplifies the implementation of this for the
custom functions registered below.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
return None if None in args else func(*args, **kwargs)
return wrapper
@@ -59,19 +60,19 @@ def list_aggregate(function):
Return an aggregate class that accumulates values in a list and applies
the provided function to the data.
"""
return type("ListAggregate", (list,), {"finalize": function, "step": list.append})
return type('ListAggregate', (list,), {'finalize': function, 'step': list.append})
def check_sqlite_version():
if Database.sqlite_version_info < (3, 9, 0):
raise ImproperlyConfigured(
"SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version
'SQLite 3.9.0 or later is required (found %s).' % Database.sqlite_version
)
check_sqlite_version()
Database.register_converter("bool", b"1".__eq__)
Database.register_converter("bool", b'1'.__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))
@@ -80,69 +81,70 @@ Database.register_adapter(decimal.Decimal, str)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "sqlite"
display_name = "SQLite"
vendor = 'sqlite'
display_name = 'SQLite'
# SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that
# schema inspection is more useful.
data_types = {
"AutoField": "integer",
"BigAutoField": "integer",
"BinaryField": "BLOB",
"BooleanField": "bool",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "datetime",
"DecimalField": "decimal",
"DurationField": "bigint",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "real",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "char(15)",
"GenericIPAddressField": "char(39)",
"JSONField": "text",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint unsigned",
"PositiveIntegerField": "integer unsigned",
"PositiveSmallIntegerField": "smallint unsigned",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "integer",
"SmallIntegerField": "smallint",
"TextField": "text",
"TimeField": "time",
"UUIDField": "char(32)",
'AutoField': 'integer',
'BigAutoField': 'integer',
'BinaryField': 'BLOB',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
'DecimalField': 'decimal',
'DurationField': 'bigint',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'real',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'char(15)',
'GenericIPAddressField': 'char(39)',
'JSONField': 'text',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint unsigned',
'PositiveIntegerField': 'integer unsigned',
'PositiveSmallIntegerField': 'smallint unsigned',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'integer',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'UUIDField': 'char(32)',
}
data_type_check_constraints = {
"PositiveBigIntegerField": '"%(column)s" >= 0',
"JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
"PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0',
'PositiveBigIntegerField': '"%(column)s" >= 0',
'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}
data_types_suffix = {
"AutoField": "AUTOINCREMENT",
"BigAutoField": "AUTOINCREMENT",
"SmallAutoField": "AUTOINCREMENT",
'AutoField': 'AUTOINCREMENT',
'BigAutoField': 'AUTOINCREMENT',
'SmallAutoField': 'AUTOINCREMENT',
}
# SQLite requires LIKE statements to include an ESCAPE clause if the value
# being escaped has a percent or underscore in it.
# See https://www.sqlite.org/lang_expr.html for an explanation.
operators = {
"exact": "= %s",
"iexact": "LIKE %s ESCAPE '\\'",
"contains": "LIKE %s ESCAPE '\\'",
"icontains": "LIKE %s ESCAPE '\\'",
"regex": "REGEXP %s",
"iregex": "REGEXP '(?i)' || %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE %s ESCAPE '\\'",
"endswith": "LIKE %s ESCAPE '\\'",
"istartswith": "LIKE %s ESCAPE '\\'",
"iendswith": "LIKE %s ESCAPE '\\'",
'exact': '= %s',
'iexact': "LIKE %s ESCAPE '\\'",
'contains': "LIKE %s ESCAPE '\\'",
'icontains': "LIKE %s ESCAPE '\\'",
'regex': 'REGEXP %s',
'iregex': "REGEXP '(?i)' || %s",
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': "LIKE %s ESCAPE '\\'",
'endswith': "LIKE %s ESCAPE '\\'",
'istartswith': "LIKE %s ESCAPE '\\'",
'iendswith': "LIKE %s ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -155,12 +157,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
"contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
"icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
"startswith": r"LIKE {} || '%%' ESCAPE '\'",
"istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
"endswith": r"LIKE '%%' || {} ESCAPE '\'",
"iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'",
'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
'startswith': r"LIKE {} || '%%' ESCAPE '\'",
'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'",
'endswith': r"LIKE '%%' || {} ESCAPE '\'",
'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'",
}
Database = Database
@@ -174,15 +176,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
settings_dict = self.settings_dict
if not settings_dict["NAME"]:
if not settings_dict['NAME']:
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value."
)
"Please supply the NAME value.")
kwargs = {
"database": settings_dict["NAME"],
"detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
**settings_dict["OPTIONS"],
# TODO: Remove str() when dropping support for PY36.
# https://bugs.python.org/issue33496
'database': str(settings_dict['NAME']),
'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
**settings_dict['OPTIONS'],
}
# Always allow the underlying SQLite connection to be shareable
# between multiple threads. The safe-guarding will be handled at a
@@ -190,103 +193,78 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# property. This is necessary as the shareability is disabled by
# default in pysqlite and it cannot be changed once a connection is
# opened.
if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
warnings.warn(
"The `check_same_thread` option was provided and set to "
"True. It will be overridden with False. Use the "
"`DatabaseWrapper.allow_thread_sharing` property instead "
"for controlling thread shareability.",
RuntimeWarning,
'The `check_same_thread` option was provided and set to '
'True. It will be overridden with False. Use the '
'`DatabaseWrapper.allow_thread_sharing` property instead '
'for controlling thread shareability.',
RuntimeWarning
)
kwargs.update({"check_same_thread": False, "uri": True})
kwargs.update({'check_same_thread': False, 'uri': True})
return kwargs
@async_unsafe
def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
create_deterministic_function = functools.partial(
conn.create_function,
deterministic=True,
)
create_deterministic_function(
"django_date_extract", 2, _sqlite_datetime_extract
)
create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
create_deterministic_function(
"django_datetime_cast_date", 3, _sqlite_datetime_cast_date
)
create_deterministic_function(
"django_datetime_cast_time", 3, _sqlite_datetime_cast_time
)
create_deterministic_function(
"django_datetime_extract", 4, _sqlite_datetime_extract
)
create_deterministic_function(
"django_datetime_trunc", 4, _sqlite_datetime_trunc
)
create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
create_deterministic_function(
"django_timestamp_diff", 2, _sqlite_timestamp_diff
)
create_deterministic_function(
"django_format_dtdelta", 3, _sqlite_format_dtdelta
)
create_deterministic_function("regexp", 2, _sqlite_regexp)
create_deterministic_function("ACOS", 1, none_guard(math.acos))
create_deterministic_function("ASIN", 1, none_guard(math.asin))
create_deterministic_function("ATAN", 1, none_guard(math.atan))
create_deterministic_function("ATAN2", 2, none_guard(math.atan2))
create_deterministic_function("BITXOR", 2, none_guard(operator.xor))
create_deterministic_function("CEILING", 1, none_guard(math.ceil))
create_deterministic_function("COS", 1, none_guard(math.cos))
create_deterministic_function("COT", 1, none_guard(lambda x: 1 / math.tan(x)))
create_deterministic_function("DEGREES", 1, none_guard(math.degrees))
create_deterministic_function("EXP", 1, none_guard(math.exp))
create_deterministic_function("FLOOR", 1, none_guard(math.floor))
create_deterministic_function("LN", 1, none_guard(math.log))
create_deterministic_function("LOG", 2, none_guard(lambda x, y: math.log(y, x)))
create_deterministic_function("LPAD", 3, _sqlite_lpad)
create_deterministic_function(
"MD5", 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest())
)
create_deterministic_function("MOD", 2, none_guard(math.fmod))
create_deterministic_function("PI", 0, lambda: math.pi)
create_deterministic_function("POWER", 2, none_guard(operator.pow))
create_deterministic_function("RADIANS", 1, none_guard(math.radians))
create_deterministic_function("REPEAT", 2, none_guard(operator.mul))
create_deterministic_function("REVERSE", 1, none_guard(lambda x: x[::-1]))
create_deterministic_function("RPAD", 3, _sqlite_rpad)
create_deterministic_function(
"SHA1", 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA224", 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA256", 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA384", 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA512", 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest())
)
create_deterministic_function(
"SIGN", 1, none_guard(lambda x: (x > 0) - (x < 0))
)
create_deterministic_function("SIN", 1, none_guard(math.sin))
create_deterministic_function("SQRT", 1, none_guard(math.sqrt))
create_deterministic_function("TAN", 1, none_guard(math.tan))
if PY38:
create_deterministic_function = functools.partial(
conn.create_function,
deterministic=True,
)
else:
create_deterministic_function = conn.create_function
create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
create_deterministic_function('regexp', 2, _sqlite_regexp)
create_deterministic_function('ACOS', 1, none_guard(math.acos))
create_deterministic_function('ASIN', 1, none_guard(math.asin))
create_deterministic_function('ATAN', 1, none_guard(math.atan))
create_deterministic_function('ATAN2', 2, none_guard(math.atan2))
create_deterministic_function('BITXOR', 2, none_guard(operator.xor))
create_deterministic_function('CEILING', 1, none_guard(math.ceil))
create_deterministic_function('COS', 1, none_guard(math.cos))
create_deterministic_function('COT', 1, none_guard(lambda x: 1 / math.tan(x)))
create_deterministic_function('DEGREES', 1, none_guard(math.degrees))
create_deterministic_function('EXP', 1, none_guard(math.exp))
create_deterministic_function('FLOOR', 1, none_guard(math.floor))
create_deterministic_function('LN', 1, none_guard(math.log))
create_deterministic_function('LOG', 2, none_guard(lambda x, y: math.log(y, x)))
create_deterministic_function('LPAD', 3, _sqlite_lpad)
create_deterministic_function('MD5', 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest()))
create_deterministic_function('MOD', 2, none_guard(math.fmod))
create_deterministic_function('PI', 0, lambda: math.pi)
create_deterministic_function('POWER', 2, none_guard(operator.pow))
create_deterministic_function('RADIANS', 1, none_guard(math.radians))
create_deterministic_function('REPEAT', 2, none_guard(operator.mul))
create_deterministic_function('REVERSE', 1, none_guard(lambda x: x[::-1]))
create_deterministic_function('RPAD', 3, _sqlite_rpad)
create_deterministic_function('SHA1', 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest()))
create_deterministic_function('SHA224', 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest()))
create_deterministic_function('SHA256', 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest()))
create_deterministic_function('SHA384', 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest()))
create_deterministic_function('SHA512', 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest()))
create_deterministic_function('SIGN', 1, none_guard(lambda x: (x > 0) - (x < 0)))
create_deterministic_function('SIN', 1, none_guard(math.sin))
create_deterministic_function('SQRT', 1, none_guard(math.sqrt))
create_deterministic_function('TAN', 1, none_guard(math.tan))
# Don't use the built-in RANDOM() function because it returns a value
# in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
conn.create_function("RAND", 0, random.random)
conn.create_aggregate("STDDEV_POP", 1, list_aggregate(statistics.pstdev))
conn.create_aggregate("STDDEV_SAMP", 1, list_aggregate(statistics.stdev))
conn.create_aggregate("VAR_POP", 1, list_aggregate(statistics.pvariance))
conn.create_aggregate("VAR_SAMP", 1, list_aggregate(statistics.variance))
conn.execute("PRAGMA foreign_keys = ON")
# in the range [2^63, 2^63 - 1] instead of [0, 1).
conn.create_function('RAND', 0, random.random)
conn.create_aggregate('STDDEV_POP', 1, list_aggregate(statistics.pstdev))
conn.create_aggregate('STDDEV_SAMP', 1, list_aggregate(statistics.stdev))
conn.create_aggregate('VAR_POP', 1, list_aggregate(statistics.pvariance))
conn.create_aggregate('VAR_SAMP', 1, list_aggregate(statistics.variance))
conn.execute('PRAGMA foreign_keys = ON')
return conn
def init_connection_state(self):
@@ -318,7 +296,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
# sqlite3's internal default is ''. It's different from None.
# See Modules/_sqlite/connection.c.
level = ""
level = ''
# 'isolation_level' is a misleading API.
# SQLite always runs at the SERIALIZABLE isolation level.
with self.wrap_database_errors:
@@ -326,16 +304,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def disable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute("PRAGMA foreign_keys = OFF")
cursor.execute('PRAGMA foreign_keys = OFF')
# Foreign key constraints cannot be turned off while in a multi-
# statement transaction. Fetch the current state of the pragma
# to determine if constraints are effectively disabled.
enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0]
return not bool(enabled)
def enable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute("PRAGMA foreign_keys = ON")
cursor.execute('PRAGMA foreign_keys = ON')
def check_constraints(self, table_names=None):
"""
@@ -348,32 +326,24 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if self.features.supports_pragma_foreign_key_check:
with self.cursor() as cursor:
if table_names is None:
violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
else:
violations = chain.from_iterable(
cursor.execute(
"PRAGMA foreign_key_check(%s)"
'PRAGMA foreign_key_check(%s)'
% self.ops.quote_name(table_name)
).fetchall()
for table_name in table_names
)
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
for (
table_name,
rowid,
referenced_table_name,
foreign_key_index,
) in violations:
for table_name, rowid, referenced_table_name, foreign_key_index in violations:
foreign_key = cursor.execute(
"PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
'PRAGMA foreign_key_list(%s)' % self.ops.quote_name(table_name)
).fetchall()[foreign_key_index]
column_name, referenced_column_name = foreign_key[3:5]
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
primary_key_value, bad_value = cursor.execute(
"SELECT %s, %s FROM %s WHERE rowid = %%s"
% (
'SELECT %s, %s FROM %s WHERE rowid = %%s' % (
self.ops.quote_name(primary_key_column_name),
self.ops.quote_name(column_name),
self.ops.quote_name(table_name),
@@ -383,15 +353,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
primary_key_value,
table_name,
column_name,
bad_value,
referenced_table_name,
referenced_column_name,
"does not have a corresponding value in %s.%s." % (
table_name, primary_key_value, table_name, column_name,
bad_value, referenced_table_name, referenced_column_name
)
)
else:
@@ -399,17 +363,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for (
column_name,
referenced_table_name,
referenced_column_name,
) in key_columns:
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
@@ -418,29 +376,18 @@ class DatabaseWrapper(BaseDatabaseWrapper):
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
primary_key_column_name,
column_name,
table_name,
referenced_table_name,
column_name,
referenced_column_name,
column_name,
referenced_column_name,
primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
bad_row[0],
table_name,
column_name,
bad_row[1],
referenced_table_name,
referenced_column_name,
"does not have a corresponding value in %s.%s." % (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
@@ -457,10 +404,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.cursor().execute("BEGIN")
def is_in_memory_db(self):
return self.creation.is_in_memory_db(self.settings_dict["NAME"])
return self.creation.is_in_memory_db(self.settings_dict['NAME'])
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
FORMAT_QMARK_REGEX = _lazy_re_compile(r'(?<!%)%s')
class SQLiteCursorWrapper(Database.Cursor):
@@ -469,7 +416,6 @@ class SQLiteCursorWrapper(Database.Cursor):
This fixes it -- but note that if you want to use a literal "%s" in a query,
you'll need to use "%%s".
"""
def execute(self, query, params=None):
if params is None:
return Database.Cursor.execute(self, query)
@@ -481,7 +427,7 @@ class SQLiteCursorWrapper(Database.Cursor):
return Database.Cursor.executemany(self, query, param_list)
def convert_query(self, query):
return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%')
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
@@ -492,14 +438,17 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
except (TypeError, ValueError):
return None
if conn_tzname:
dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
dt = dt.replace(tzinfo=pytz.timezone(conn_tzname))
if tzname is not None and tzname != conn_tzname:
tzname, sign, offset = backend_utils.split_tzname_delta(tzname)
if offset:
hours, minutes = offset.split(":")
offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
dt += offset_delta if sign == "+" else -offset_delta
dt = timezone.localtime(dt, timezone_constructor(tzname))
sign_index = tzname.find('+') + tzname.find('-') + 1
if sign_index > -1:
sign = tzname[sign_index]
tzname, offset = tzname.split(sign)
if offset:
hours, minutes = offset.split(':')
offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
dt += offset_delta if sign == '+' else -offset_delta
dt = timezone.localtime(dt, pytz.timezone(tzname))
return dt
@@ -507,17 +456,17 @@ def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "year":
if lookup_type == 'year':
return "%i-01-01" % dt.year
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
month_in_quarter = dt.month - (dt.month - 1) % 3
return "%i-%02i-01" % (dt.year, month_in_quarter)
elif lookup_type == "month":
return '%i-%02i-01' % (dt.year, month_in_quarter)
elif lookup_type == 'month':
return "%i-%02i-01" % (dt.year, dt.month)
elif lookup_type == "week":
elif lookup_type == 'week':
dt = dt - datetime.timedelta(days=dt.weekday())
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
elif lookup_type == "day":
elif lookup_type == 'day':
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
@@ -532,11 +481,11 @@ def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
return None
else:
dt = dt_parsed
if lookup_type == "hour":
if lookup_type == 'hour':
return "%02i:00:00" % dt.hour
elif lookup_type == "minute":
elif lookup_type == 'minute':
return "%02i:%02i:00" % (dt.hour, dt.minute)
elif lookup_type == "second":
elif lookup_type == 'second':
return "%02i:%02i:%02i" % (dt.hour, dt.minute, dt.second)
@@ -558,15 +507,15 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "week_day":
if lookup_type == 'week_day':
return (dt.isoweekday() % 7) + 1
elif lookup_type == "iso_week_day":
elif lookup_type == 'iso_week_day':
return dt.isoweekday()
elif lookup_type == "week":
elif lookup_type == 'week':
return dt.isocalendar()[1]
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
return math.ceil(dt.month / 3)
elif lookup_type == "iso_year":
elif lookup_type == 'iso_year':
return dt.isocalendar()[0]
else:
return getattr(dt, lookup_type)
@@ -576,37 +525,24 @@ def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "year":
if lookup_type == 'year':
return "%i-01-01 00:00:00" % dt.year
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
month_in_quarter = dt.month - (dt.month - 1) % 3
return "%i-%02i-01 00:00:00" % (dt.year, month_in_quarter)
elif lookup_type == "month":
return '%i-%02i-01 00:00:00' % (dt.year, month_in_quarter)
elif lookup_type == 'month':
return "%i-%02i-01 00:00:00" % (dt.year, dt.month)
elif lookup_type == "week":
elif lookup_type == 'week':
dt = dt - datetime.timedelta(days=dt.weekday())
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
elif lookup_type == "day":
elif lookup_type == 'day':
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
elif lookup_type == "hour":
elif lookup_type == 'hour':
return "%i-%02i-%02i %02i:00:00" % (dt.year, dt.month, dt.day, dt.hour)
elif lookup_type == "minute":
return "%i-%02i-%02i %02i:%02i:00" % (
dt.year,
dt.month,
dt.day,
dt.hour,
dt.minute,
)
elif lookup_type == "second":
return "%i-%02i-%02i %02i:%02i:%02i" % (
dt.year,
dt.month,
dt.day,
dt.hour,
dt.minute,
dt.second,
)
elif lookup_type == 'minute':
return "%i-%02i-%02i %02i:%02i:00" % (dt.year, dt.month, dt.day, dt.hour, dt.minute)
elif lookup_type == 'second':
return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
def _sqlite_time_extract(lookup_type, dt):
@@ -619,40 +555,25 @@ def _sqlite_time_extract(lookup_type, dt):
return getattr(dt, lookup_type)
def _sqlite_prepare_dtdelta_param(conn, param):
if conn in ["+", "-"]:
if isinstance(param, int):
return datetime.timedelta(0, 0, param)
else:
return backend_utils.typecast_timestamp(param)
return param
@none_guard
def _sqlite_format_dtdelta(conn, lhs, rhs):
"""
LHS and RHS can be either:
- An integer number of microseconds
- A string representing a datetime
- A scalar value, e.g. float
"""
conn = conn.strip()
try:
real_lhs = _sqlite_prepare_dtdelta_param(conn, lhs)
real_rhs = _sqlite_prepare_dtdelta_param(conn, rhs)
real_lhs = datetime.timedelta(0, 0, lhs) if isinstance(lhs, int) else backend_utils.typecast_timestamp(lhs)
real_rhs = datetime.timedelta(0, 0, rhs) if isinstance(rhs, int) else backend_utils.typecast_timestamp(rhs)
if conn.strip() == '+':
out = real_lhs + real_rhs
else:
out = real_lhs - real_rhs
except (ValueError, TypeError):
return None
if conn == "+":
# typecast_timestamp returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
out = str(real_lhs + real_rhs)
elif conn == "-":
out = str(real_lhs - real_rhs)
elif conn == "*":
out = real_lhs * real_rhs
else:
out = real_lhs / real_rhs
return out
# typecast_timestamp returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
return str(out)
@none_guard
@@ -660,14 +581,14 @@ def _sqlite_time_diff(lhs, rhs):
left = backend_utils.typecast_time(lhs)
right = backend_utils.typecast_time(rhs)
return (
(left.hour * 60 * 60 * 1000000)
+ (left.minute * 60 * 1000000)
+ (left.second * 1000000)
+ (left.microsecond)
- (right.hour * 60 * 60 * 1000000)
- (right.minute * 60 * 1000000)
- (right.second * 1000000)
- (right.microsecond)
(left.hour * 60 * 60 * 1000000) +
(left.minute * 60 * 1000000) +
(left.second * 1000000) +
(left.microsecond) -
(right.hour * 60 * 60 * 1000000) -
(right.minute * 60 * 1000000) -
(right.second * 1000000) -
(right.microsecond)
)
@@ -687,7 +608,7 @@ def _sqlite_regexp(re_pattern, re_string):
def _sqlite_lpad(text, length, fill_text):
if len(text) >= length:
return text[:length]
return (fill_text * length)[: length - len(text)] + text
return (fill_text * length)[:length - len(text)] + text
@none_guard
@@ -2,9 +2,15 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "sqlite3"
executable_name = 'sqlite3'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, settings_dict["NAME"], *parameters]
args = [
cls.executable_name,
# TODO: Remove str() when dropping support for PY37. args
# parameter accepts path-like objects on Windows since Python 3.8.
str(settings_dict['NAME']),
*parameters,
]
return args, None
@@ -7,16 +7,17 @@ from django.db.backends.base.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
@staticmethod
def is_in_memory_db(database_name):
return not isinstance(database_name, Path) and (
database_name == ":memory:" or "mode=memory" in database_name
database_name == ':memory:' or 'mode=memory' in database_name
)
def _get_test_db_name(self):
test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
if test_database_name == ":memory:":
return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:'
if test_database_name == ':memory:':
return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias
return test_database_name
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
@@ -27,39 +28,38 @@ class DatabaseCreation(BaseDatabaseCreation):
if not self.is_in_memory_db(test_database_name):
# Erase the old test database
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (self._get_database_display_str(verbosity, test_database_name),)
)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, test_database_name),
))
if os.access(test_database_name, os.F_OK):
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name
)
if autoclobber or confirm == "yes":
if autoclobber or confirm == 'yes':
try:
os.remove(test_database_name)
except Exception as e:
self.log("Got an error deleting the old test database: %s" % e)
self.log('Got an error deleting the old test database: %s' % e)
sys.exit(2)
else:
self.log("Tests cancelled.")
self.log('Tests cancelled.')
sys.exit(1)
return test_database_name
def get_test_db_clone_settings(self, suffix):
orig_settings_dict = self.connection.settings_dict
source_database_name = orig_settings_dict["NAME"]
source_database_name = orig_settings_dict['NAME']
if self.is_in_memory_db(source_database_name):
return orig_settings_dict
else:
root, ext = os.path.splitext(orig_settings_dict["NAME"])
return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)}
root, ext = os.path.splitext(orig_settings_dict['NAME'])
return {**orig_settings_dict, 'NAME': '{}_{}.{}'.format(root, suffix, ext)}
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
# Forking automatically makes a copy of an in-memory database.
if not self.is_in_memory_db(source_database_name):
# Erase the old test database
@@ -67,23 +67,18 @@ class DatabaseCreation(BaseDatabaseCreation):
if keepdb:
return
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
try:
os.remove(target_database_name)
except Exception as e:
self.log("Got an error deleting the old test database: %s" % e)
self.log('Got an error deleting the old test database: %s' % e)
sys.exit(2)
try:
shutil.copy(source_database_name, target_database_name)
except Exception as e:
self.log("Got an error cloning the test database: %s" % e)
self.log('Got an error cloning the test database: %s' % e)
sys.exit(2)
def _destroy_test_db(self, test_database_name, verbosity):
@@ -100,7 +95,7 @@ class DatabaseCreation(BaseDatabaseCreation):
TEST NAME. See https://www.sqlite.org/inmemorydb.html
"""
test_database_name = self._get_test_db_name()
sig = [self.connection.settings_dict["NAME"]]
sig = [self.connection.settings_dict['NAME']]
if self.is_in_memory_db(test_database_name):
sig.append(self.connection.alias)
else:
@@ -45,82 +45,58 @@ class DatabaseFeatures(BaseDatabaseFeatures):
order_by_nulls_first = True
supports_json_field_contains = False
test_collations = {
"ci": "nocase",
"cs": "binary",
"non_default": "nocase",
'ci': 'nocase',
'cs': 'binary',
'non_default': 'nocase',
}
@cached_property
def django_test_skips(self):
skips = {
"SQLite stores values rounded to 15 significant digits.": {
"model_fields.test_decimalfield.DecimalFieldTests."
"test_fetch_from_db_without_float_rounding",
'SQLite stores values rounded to 15 significant digits.': {
'model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding',
},
"SQLite naively remakes the table on field alteration.": {
"schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops",
"schema.tests.SchemaTests.test_unique_and_reverse_m2m",
"schema.tests.SchemaTests."
"test_alter_field_default_doesnt_perform_queries",
"schema.tests.SchemaTests."
"test_rename_column_renames_deferred_sql_references",
'SQLite naively remakes the table on field alteration.': {
'schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops',
'schema.tests.SchemaTests.test_unique_and_reverse_m2m',
'schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries',
'schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references',
},
"SQLite doesn't have a constraint.": {
"model_fields.test_integerfield.PositiveIntegerFieldTests."
"test_negative_values",
},
"SQLite doesn't support negative precision for ROUND().": {
"db_functions.math.test_round.RoundTests."
"test_null_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_decimal_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_float_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_integer_with_negative_precision",
'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
},
}
if Database.sqlite_version_info < (3, 27):
skips.update(
{
"Nondeterministic failure on SQLite < 3.27.": {
"expressions_window.tests.WindowFunctionTests."
"test_subquery_row_range_rank",
},
}
)
skips.update({
'Nondeterministic failure on SQLite < 3.27.': {
'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank',
},
})
if self.connection.is_in_memory_db():
skips.update(
{
"the sqlite backend's close() method is a no-op when using an "
"in-memory database": {
"servers.test_liveserverthread.LiveServerThreadTest."
"test_closes_connections",
"servers.tests.LiveServerTestCloseConnectionTest."
"test_closes_connections",
},
}
)
skips.update({
"the sqlite backend's close() method is a no-op when using an "
"in-memory database": {
'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections',
},
})
return skips
@cached_property
def supports_atomic_references_rename(self):
# SQLite 3.28.0 bundled with MacOS 10.15 does not support renaming
# references atomically.
if platform.mac_ver()[0].startswith(
"10.15."
) and Database.sqlite_version_info == (3, 28, 0):
if platform.mac_ver()[0].startswith('10.15.') and Database.sqlite_version_info == (3, 28, 0):
return False
return Database.sqlite_version_info >= (3, 26, 0)
@cached_property
def introspected_field_types(self):
return {
return{
**super().introspected_field_types,
"BigAutoField": "AutoField",
"DurationField": "BigIntegerField",
"GenericIPAddressField": "CharField",
"SmallAutoField": "AutoField",
'BigAutoField': 'AutoField',
'DurationField': 'BigIntegerField',
'GenericIPAddressField': 'CharField',
'SmallAutoField': 'AutoField',
}
@cached_property
@@ -133,13 +109,5 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
return True
can_introspect_json_field = property(operator.attrgetter("supports_json_field"))
has_json_object_function = property(operator.attrgetter("supports_json_field"))
@cached_property
def can_return_columns_from_insert(self):
return Database.sqlite_version_info >= (3, 35)
can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert")
)
can_introspect_json_field = property(operator.attrgetter('supports_json_field'))
has_json_object_function = property(operator.attrgetter('supports_json_field'))
@@ -3,21 +3,19 @@ from collections import namedtuple
import sqlparse
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile
FieldInfo = namedtuple(
"FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
)
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint'))
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$')
def get_field_size(name):
"""Extract the size number from a "varchar(11)" type name"""
""" Extract the size number from a "varchar(11)" type name """
m = field_size_re.search(name)
return int(m[1]) if m else None
@@ -30,29 +28,29 @@ class FlexibleFieldLookupDict:
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
base_data_types_reverse = {
"bool": "BooleanField",
"boolean": "BooleanField",
"smallint": "SmallIntegerField",
"smallint unsigned": "PositiveSmallIntegerField",
"smallinteger": "SmallIntegerField",
"int": "IntegerField",
"integer": "IntegerField",
"bigint": "BigIntegerField",
"integer unsigned": "PositiveIntegerField",
"bigint unsigned": "PositiveBigIntegerField",
"decimal": "DecimalField",
"real": "FloatField",
"text": "TextField",
"char": "CharField",
"varchar": "CharField",
"blob": "BinaryField",
"date": "DateField",
"datetime": "DateTimeField",
"time": "TimeField",
'bool': 'BooleanField',
'boolean': 'BooleanField',
'smallint': 'SmallIntegerField',
'smallint unsigned': 'PositiveSmallIntegerField',
'smallinteger': 'SmallIntegerField',
'int': 'IntegerField',
'integer': 'IntegerField',
'bigint': 'BigIntegerField',
'integer unsigned': 'PositiveIntegerField',
'bigint unsigned': 'PositiveBigIntegerField',
'decimal': 'DecimalField',
'real': 'FloatField',
'text': 'TextField',
'char': 'CharField',
'varchar': 'CharField',
'blob': 'BinaryField',
'date': 'DateField',
'datetime': 'DateTimeField',
'time': 'TimeField',
}
def __getitem__(self, key):
key = key.lower().split("(", 1)[0].strip()
key = key.lower().split('(', 1)[0].strip()
return self.base_data_types_reverse[key]
@@ -61,28 +59,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.pk and field_type in {
"BigIntegerField",
"IntegerField",
"SmallIntegerField",
}:
if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}:
# No support for BigAutoField or SmallAutoField as SQLite treats
# all integer primary keys as signed 64-bit integers.
return "AutoField"
return 'AutoField'
if description.has_json_constraint:
return "JSONField"
return 'JSONField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute(
"""
cursor.execute("""
SELECT name, type FROM sqlite_master
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
ORDER BY name"""
)
ORDER BY name""")
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
@@ -90,9 +82,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
Return a description of the table with the DB-API cursor.description
interface.
"""
cursor.execute(
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
)
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name))
table_info = cursor.fetchall()
collations = self._get_column_collations(cursor, table_name)
json_columns = set()
@@ -100,39 +90,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for line in table_info:
column = line[1]
json_constraint_sql = '%%json_valid("%s")%%' % column
has_json_constraint = cursor.execute(
"""
has_json_constraint = cursor.execute("""
SELECT sql
FROM sqlite_master
WHERE
type = 'table' AND
name = %s AND
sql LIKE %s
""",
[table_name, json_constraint_sql],
).fetchone()
""", [table_name, json_constraint_sql]).fetchone()
if has_json_constraint:
json_columns.add(column)
return [
FieldInfo(
name,
data_type,
None,
get_field_size(data_type),
None,
None,
not notnull,
default,
collations.get(name),
pk == 1,
name in json_columns,
name, data_type, None, get_field_size(data_type), None, None,
not notnull, default, collations.get(name), pk == 1, name in json_columns
)
for cid, name, data_type, notnull, default, pk in table_info
]
def get_sequences(self, cursor, table_name, table_fields=()):
pk_col = self.get_primary_key_column(cursor, table_name)
return [{"table": table_name, "column": pk_col}]
return [{'table': table_name, 'column': pk_col}]
def get_relations(self, cursor, table_name):
"""
@@ -146,18 +124,18 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
cursor.execute(
"SELECT sql, type FROM sqlite_master "
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name],
[table_name]
)
create_sql, table_type = cursor.fetchone()
if table_type == "view":
if table_type == 'view':
# It might be a view, then no results will be returned
return relations
results = create_sql[create_sql.index("(") + 1 : create_sql.rindex(")")]
results = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_desc in results.split(","):
for field_desc in results.split(','):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
@@ -169,7 +147,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
if field_desc.startswith("FOREIGN KEY"):
# Find name of the target FK field
m = re.match(r"FOREIGN KEY\s*\(([^\)]*)\).*", field_desc, re.I)
m = re.match(r'FOREIGN KEY\s*\(([^\)]*)\).*', field_desc, re.I)
field_name = m[1].strip('"')
else:
field_name = field_desc.split()[0].strip('"')
@@ -177,15 +155,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
result = cursor.fetchall()[0]
other_table_results = result[0].strip()
li, ri = other_table_results.index("("), other_table_results.rindex(")")
other_table_results = other_table_results[li + 1 : ri]
li, ri = other_table_results.index('('), other_table_results.rindex(')')
other_table_results = other_table_results[li + 1:ri]
for other_desc in other_table_results.split(","):
for other_desc in other_table_results.split(','):
other_desc = other_desc.strip()
if other_desc.startswith("UNIQUE"):
if other_desc.startswith('UNIQUE'):
continue
other_name = other_desc.split(" ", 1)[0].strip('"')
other_name = other_desc.split(' ', 1)[0].strip('"')
if other_name == column:
relations[field_name] = (other_name, table)
break
@@ -200,17 +178,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
key_columns = []
# Schema for this table
cursor.execute(
"SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s",
[table_name, "table"],
)
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"])
results = cursor.fetchone()[0].strip()
results = results[results.index("(") + 1 : results.rindex(")")]
results = results[results.index('(') + 1:results.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_index, field_desc in enumerate(results.split(",")):
for field_index, field_desc in enumerate(results.split(',')):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
@@ -219,9 +194,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
if not m:
continue
# This will append
# (column_name, referenced_table_name, referenced_column_name) to
# key_columns.
# This will append (column_name, referenced_table_name, referenced_column_name) to key_columns
key_columns.append(tuple(s.strip('"') for s in m.groups()))
return key_columns
@@ -232,40 +205,36 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
cursor.execute(
"SELECT sql, type FROM sqlite_master "
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name],
[table_name]
)
row = cursor.fetchone()
if row is None:
raise ValueError("Table %s does not exist" % table_name)
create_sql, table_type = row
if table_type == "view":
if table_type == 'view':
# Views don't have a primary key.
return None
fields_sql = create_sql[create_sql.index("(") + 1 : create_sql.rindex(")")]
for field_desc in fields_sql.split(","):
fields_sql = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
for field_desc in fields_sql.split(','):
field_desc = field_desc.strip()
m = re.match(
r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc
)
m = re.match(r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc)
if m:
return m[1] if m[1] else m[2]
return None
def _get_foreign_key_constraints(self, cursor, table_name):
constraints = {}
cursor.execute(
"PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
)
cursor.execute('PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
# Remaining on_update/on_delete/match values are of no interest.
id_, _, table, from_, to = row[:5]
constraints["fk_%d" % id_] = {
"columns": [from_],
"primary_key": False,
"unique": False,
"foreign_key": (table, to),
"check": False,
"index": False,
constraints['fk_%d' % id_] = {
'columns': [from_],
'primary_key': False,
'unique': False,
'foreign_key': (table, to),
'check': False,
'index': False,
}
return constraints
@@ -280,21 +249,19 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
check_columns = []
braces_deep = 0
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, "("):
if token.match(sqlparse.tokens.Punctuation, '('):
braces_deep += 1
elif token.match(sqlparse.tokens.Punctuation, ")"):
elif token.match(sqlparse.tokens.Punctuation, ')'):
braces_deep -= 1
if braces_deep < 0:
# End of columns and constraints for table definition.
break
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
# End of current column or constraint definition.
break
# Detect column or constraint definition by first token.
if is_constraint_definition is None:
is_constraint_definition = token.match(
sqlparse.tokens.Keyword, "CONSTRAINT"
)
is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
if is_constraint_definition:
continue
if is_constraint_definition:
@@ -305,7 +272,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
constraint_name = token.value[1:-1]
# Start constraint columns parsing after UNIQUE keyword.
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
unique = True
unique_braces_deep = braces_deep
elif unique:
@@ -325,10 +292,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
field_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
field_name = token.value[1:-1]
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
unique_columns = [field_name]
# Start constraint columns parsing after CHECK keyword.
if token.match(sqlparse.tokens.Keyword, "CHECK"):
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
check = True
check_braces_deep = braces_deep
elif check:
@@ -343,30 +310,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
if token.value[1:-1] in columns:
check_columns.append(token.value[1:-1])
unique_constraint = (
{
"unique": True,
"columns": unique_columns,
"primary_key": False,
"foreign_key": None,
"check": False,
"index": False,
}
if unique_columns
else None
)
check_constraint = (
{
"check": True,
"columns": check_columns,
"primary_key": False,
"unique": False,
"foreign_key": None,
"index": False,
}
if check_columns
else None
)
unique_constraint = {
'unique': True,
'columns': unique_columns,
'primary_key': False,
'foreign_key': None,
'check': False,
'index': False,
} if unique_columns else None
check_constraint = {
'check': True,
'columns': check_columns,
'primary_key': False,
'unique': False,
'foreign_key': None,
'index': False,
} if check_columns else None
return constraint_name, unique_constraint, check_constraint, token
def _parse_table_constraints(self, sql, columns):
@@ -378,33 +337,24 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
tokens = (token for token in statement.flatten() if not token.is_whitespace)
# Go to columns and constraint definition
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, "("):
if token.match(sqlparse.tokens.Punctuation, '('):
break
# Parse columns and constraint definition
while True:
(
constraint_name,
unique,
check,
end_token,
) = self._parse_column_or_constraint_definition(tokens, columns)
constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
if unique:
if constraint_name:
constraints[constraint_name] = unique
else:
unnamed_constrains_index += 1
constraints[
"__unnamed_constraint_%s__" % unnamed_constrains_index
] = unique
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
if check:
if constraint_name:
constraints[constraint_name] = check
else:
unnamed_constrains_index += 1
constraints[
"__unnamed_constraint_%s__" % unnamed_constrains_index
] = check
if end_token.match(sqlparse.tokens.Punctuation, ")"):
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
if end_token.match(sqlparse.tokens.Punctuation, ')'):
break
return constraints
@@ -417,22 +367,19 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Find inline check constraints.
try:
table_schema = cursor.execute(
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s"
% (self.connection.ops.quote_name(table_name),)
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % (
self.connection.ops.quote_name(table_name),
)
).fetchone()[0]
except TypeError:
# table_name is a view.
pass
else:
columns = {
info.name for info in self.get_table_description(cursor, table_name)
}
columns = {info.name for info in self.get_table_description(cursor, table_name)}
constraints.update(self._parse_table_constraints(table_schema, columns))
# Get the index info
cursor.execute(
"PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
)
cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
# columns. Discard last 2 columns if there.
@@ -442,7 +389,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
)
# There's at most one row.
(sql,) = cursor.fetchone() or (None,)
sql, = cursor.fetchone() or (None,)
# Inline constraints are already detected in
# _parse_table_constraints(). The reasons to avoid fetching inline
# constraints from `PRAGMA index_list` are:
@@ -453,9 +400,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# An inline constraint
continue
# Get the index info for that index
cursor.execute(
"PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
)
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
for index_rank, column_rank, column in cursor.fetchall():
if index not in constraints:
constraints[index] = {
@@ -466,14 +411,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"check": False,
"index": True,
}
constraints[index]["columns"].append(column)
constraints[index]['columns'].append(column)
# Add type and column orders for indexes
if constraints[index]["index"]:
if constraints[index]['index']:
# SQLite doesn't support any index type other than b-tree
constraints[index]["type"] = Index.suffix
constraints[index]['type'] = Index.suffix
orders = self._get_index_columns_orders(sql)
if orders is not None:
constraints[index]["orders"] = orders
constraints[index]['orders'] = orders
# Get the PK
pk_column = self.get_primary_key_column(cursor, table_name)
if pk_column:
@@ -496,30 +441,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
tokens = sqlparse.parse(sql)[0]
for token in tokens:
if isinstance(token, sqlparse.sql.Parenthesis):
columns = str(token).strip("()").split(", ")
return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
columns = str(token).strip('()').split(', ')
return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns]
return None
def _get_column_collations(self, cursor, table_name):
row = cursor.execute(
"""
row = cursor.execute("""
SELECT sql
FROM sqlite_master
WHERE type = 'table' AND name = %s
""",
[table_name],
).fetchone()
""", [table_name]).fetchone()
if not row:
return {}
sql = row[0]
columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ')
collations = {}
for column in columns:
tokens = column[1:].split()
column_name = tokens[0].strip('"')
for index, token in enumerate(tokens):
if token == "COLLATE":
if token == 'COLLATE':
collation = tokens[index + 1]
break
else:
@@ -15,15 +15,12 @@ from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "text"
cast_char_field_without_max_length = 'text'
cast_data_types = {
"DateField": "TEXT",
"DateTimeField": "TEXT",
'DateField': 'TEXT',
'DateTimeField': 'TEXT',
}
explain_prefix = "EXPLAIN QUERY PLAN"
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
# SQLite. Use JSON_TYPE() instead.
jsonfield_datatype_values = frozenset(["null", "false", "true"])
explain_prefix = 'EXPLAIN QUERY PLAN'
def bulk_batch_size(self, fields, objs):
"""
@@ -54,14 +51,14 @@ class DatabaseOperations(BaseDatabaseOperations):
else:
if isinstance(output_field, bad_fields):
raise NotSupportedError(
"You cannot use Sum, Avg, StdDev, and Variance "
"aggregations on date/time fields in sqlite3 "
"since date/time is saved as text."
'You cannot use Sum, Avg, StdDev, and Variance '
'aggregations on date/time fields in sqlite3 '
'since date/time is saved as text.'
)
if (
isinstance(expression, models.Aggregate)
and expression.distinct
and len(expression.source_expressions) > 1
isinstance(expression, models.Aggregate) and
expression.distinct and
len(expression.source_expressions) > 1
):
raise NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
@@ -76,13 +73,6 @@ class DatabaseOperations(BaseDatabaseOperations):
"""
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the list of returned data.
"""
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
"""Do nothing since formatting is handled in the custom function."""
return sql
@@ -104,32 +94,26 @@ class DatabaseOperations(BaseDatabaseOperations):
def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
return "NULL", "NULL"
return 'NULL', 'NULL'
def datetime_cast_date_sql(self, field_name, tzname):
return "django_datetime_cast_date(%s, %s, %s)" % (
field_name,
*self._convert_tznames_to_sql(tzname),
return 'django_datetime_cast_date(%s, %s, %s)' % (
field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_cast_time_sql(self, field_name, tzname):
return "django_datetime_cast_time(%s, %s, %s)" % (
field_name,
*self._convert_tznames_to_sql(tzname),
return 'django_datetime_cast_time(%s, %s, %s)' % (
field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
return "django_datetime_extract('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
return "django_datetime_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, field_name):
@@ -151,11 +135,11 @@ class DatabaseOperations(BaseDatabaseOperations):
if len(params) > BATCH_SIZE:
results = ()
for index in range(0, len(params), BATCH_SIZE):
chunk = params[index : index + BATCH_SIZE]
chunk = params[index:index + BATCH_SIZE]
results += self._quote_params_for_last_executed_query(chunk)
return results
sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
# Bypass Django's wrappers and use the underlying sqlite3 connection
# to avoid logging this query - it would trigger infinite recursion.
cursor = self.connection.connection.cursor()
@@ -167,9 +151,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def last_executed_query(self, cursor, sql, params):
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
# pysqlite_statement_bind_parameters(
# self->statement, parameters, allow_8bit_chars
# );
# pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
# Unfortunately there is no way to reach self->statement from Python,
# so we quote and substitute parameters manually.
if params:
@@ -222,20 +204,14 @@ class DatabaseOperations(BaseDatabaseOperations):
if tables and allow_cascade:
# Simulate TRUNCATE CASCADE by recursively collecting the tables
# referencing the tables to be flushed.
tables = set(
chain.from_iterable(self._references_graph(table) for table in tables)
)
sql = [
"%s %s %s;"
% (
style.SQL_KEYWORD("DELETE"),
style.SQL_KEYWORD("FROM"),
style.SQL_FIELD(self.quote_name(table)),
)
for table in tables
]
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
sql = ['%s %s %s;' % (
style.SQL_KEYWORD('DELETE'),
style.SQL_KEYWORD('FROM'),
style.SQL_FIELD(self.quote_name(table))
) for table in tables]
if reset_sequences:
sequences = [{"table": table} for table in tables]
sequences = [{'table': table} for table in tables]
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
@@ -243,18 +219,17 @@ class DatabaseOperations(BaseDatabaseOperations):
if not sequences:
return []
return [
"%s %s %s %s = 0 %s %s %s (%s);"
% (
style.SQL_KEYWORD("UPDATE"),
style.SQL_TABLE(self.quote_name("sqlite_sequence")),
style.SQL_KEYWORD("SET"),
style.SQL_FIELD(self.quote_name("seq")),
style.SQL_KEYWORD("WHERE"),
style.SQL_FIELD(self.quote_name("name")),
style.SQL_KEYWORD("IN"),
", ".join(
["'%s'" % sequence_info["table"] for sequence_info in sequences]
),
'%s %s %s %s = 0 %s %s %s (%s);' % (
style.SQL_KEYWORD('UPDATE'),
style.SQL_TABLE(self.quote_name('sqlite_sequence')),
style.SQL_KEYWORD('SET'),
style.SQL_FIELD(self.quote_name('seq')),
style.SQL_KEYWORD('WHERE'),
style.SQL_FIELD(self.quote_name('name')),
style.SQL_KEYWORD('IN'),
', '.join([
"'%s'" % sequence_info['table'] for sequence_info in sequences
]),
),
]
@@ -263,7 +238,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
@@ -271,10 +246,7 @@ class DatabaseOperations(BaseDatabaseOperations):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"SQLite backend does not support timezone-aware datetimes when "
"USE_TZ is False."
)
raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")
return str(value)
@@ -283,7 +255,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
@@ -295,17 +267,17 @@ class DatabaseOperations(BaseDatabaseOperations):
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == "DateTimeField":
if internal_type == 'DateTimeField':
converters.append(self.convert_datetimefield_value)
elif internal_type == "DateField":
elif internal_type == 'DateField':
converters.append(self.convert_datefield_value)
elif internal_type == "TimeField":
elif internal_type == 'TimeField':
converters.append(self.convert_timefield_value)
elif internal_type == "DecimalField":
elif internal_type == 'DecimalField':
converters.append(self.get_decimalfield_converter(expression))
elif internal_type == "UUIDField":
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
elif internal_type == "BooleanField":
elif internal_type in ('NullBooleanField', 'BooleanField'):
converters.append(self.convert_booleanfield_value)
return converters
@@ -334,22 +306,15 @@ class DatabaseOperations(BaseDatabaseOperations):
# float inaccuracy must be removed.
create_decimal = decimal.Context(prec=15).create_decimal_from_float
if isinstance(expression, Col):
quantize_value = decimal.Decimal(1).scaleb(
-expression.output_field.decimal_places
)
quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)
def converter(value, expression, connection):
if value is not None:
return create_decimal(value).quantize(
quantize_value, context=expression.output_field.context
)
return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
else:
def converter(value, expression, connection):
if value is not None:
return create_decimal(value)
return converter
def convert_uuidfield_value(self, value, expression, connection):
@@ -362,25 +327,26 @@ class DatabaseOperations(BaseDatabaseOperations):
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s" % ", ".join(row) for row in placeholder_rows
"SELECT %s" % ", ".join(row)
for row in placeholder_rows
)
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect().
if connector == "^":
return "POWER(%s)" % ",".join(sub_expressions)
elif connector == "#":
return "BITXOR(%s)" % ",".join(sub_expressions)
if connector == '^':
return 'POWER(%s)' % ','.join(sub_expressions)
elif connector == '#':
return 'BITXOR(%s)' % ','.join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
if connector not in ["+", "-", "*", "/"]:
raise DatabaseError("Invalid connector for timedelta: %s." % connector)
if connector not in ['+', '-']:
raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3:
raise ValueError("Too many params for timedelta operations.")
return "django_format_dtdelta(%s)" % ", ".join(fn_params)
raise ValueError('Too many params for timedelta operations.')
return "django_format_dtdelta(%s)" % ', '.join(fn_params)
def integer_field_range(self, internal_type):
# SQLite doesn't enforce any integer constraints
@@ -390,27 +356,9 @@ class DatabaseOperations(BaseDatabaseOperations):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
if internal_type == "TimeField":
return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
if internal_type == 'TimeField':
return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
def insert_statement(self, ignore_conflicts=False):
return (
"INSERT OR IGNORE INTO"
if ignore_conflicts
else super().insert_statement(ignore_conflicts)
)
def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
@@ -14,9 +14,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_table = "DROP TABLE %(table)s"
sql_create_fk = None
sql_create_inline_fk = (
"REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
)
sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
sql_delete_unique = "DROP INDEX %(name)s"
@@ -25,11 +23,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# disabled. Enforce it here for the duration of the schema edition.
if not self.connection.disable_constraint_checking():
raise NotSupportedError(
"SQLite schema editor cannot be used while foreign key "
"constraint checks are enabled. Make sure to disable them "
"before entering a transaction.atomic() context because "
"SQLite does not support disabling them in the middle of "
"a multi-statement transaction."
'SQLite schema editor cannot be used while foreign key '
'constraint checks are enabled. Make sure to disable them '
'before entering a transaction.atomic() context because '
'SQLite does not support disabling them in the middle of '
'a multi-statement transaction.'
)
return super().__enter__()
@@ -44,7 +42,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# security hardening).
try:
import sqlite3
value = sqlite3.adapt(value)
except ImportError:
pass
@@ -56,7 +53,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
elif isinstance(value, (Decimal, float, int)):
return str(value)
elif isinstance(value, str):
return "'%s'" % value.replace("'", "''")
return "'%s'" % value.replace("\'", "\'\'")
elif value is None:
return "NULL"
elif isinstance(value, (bytes, bytearray, memoryview)):
@@ -65,13 +62,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# character.
return "X'%s'" % value.hex()
else:
raise ValueError(
"Cannot quote parameter value %r of type %s" % (value, type(value))
)
raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))
def _is_referenced_by_fk_constraint(
self, table_name, column_name=None, ignore_self=False
):
def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False):
"""
Return whether or not the provided table name is referenced by another
one. If `column_name` is specified, only references pointing to that
@@ -82,36 +75,23 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
for other_table in self.connection.introspection.get_table_list(cursor):
if ignore_self and other_table.name == table_name:
continue
constraints = (
self.connection.introspection._get_foreign_key_constraints(
cursor, other_table.name
)
)
constraints = self.connection.introspection._get_foreign_key_constraints(cursor, other_table.name)
for constraint in constraints.values():
constraint_table, constraint_column = constraint["foreign_key"]
if constraint_table == table_name and (
column_name is None or constraint_column == column_name
):
constraint_table, constraint_column = constraint['foreign_key']
if (constraint_table == table_name and
(column_name is None or constraint_column == column_name)):
return True
return False
def alter_db_table(
self, model, old_db_table, new_db_table, disable_constraints=True
):
if (
not self.connection.features.supports_atomic_references_rename
and disable_constraints
and self._is_referenced_by_fk_constraint(old_db_table)
):
def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
if (not self.connection.features.supports_atomic_references_rename and
disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)):
if self.connection.in_atomic_block:
raise NotSupportedError(
(
"Renaming the %r table while in a transaction is not "
"supported on SQLite < 3.26 because it would break referential "
"integrity. Try adding `atomic = False` to the Migration class."
)
% old_db_table
)
raise NotSupportedError((
'Renaming the %r table while in a transaction is not '
'supported on SQLite < 3.26 because it would break referential '
'integrity. Try adding `atomic = False` to the Migration class.'
) % old_db_table)
self.connection.enable_constraint_checking()
super().alter_db_table(model, old_db_table, new_db_table)
self.connection.disable_constraint_checking()
@@ -124,56 +104,42 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
old_field_name = old_field.name
table_name = model._meta.db_table
_, old_column_name = old_field.get_attname_column()
if (
new_field.name != old_field_name
and not self.connection.features.supports_atomic_references_rename
and self._is_referenced_by_fk_constraint(
table_name, old_column_name, ignore_self=True
)
):
if (new_field.name != old_field_name and
not self.connection.features.supports_atomic_references_rename and
self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)):
if self.connection.in_atomic_block:
raise NotSupportedError(
(
"Renaming the %r.%r column while in a transaction is not "
"supported on SQLite < 3.26 because it would break referential "
"integrity. Try adding `atomic = False` to the Migration class."
)
% (model._meta.db_table, old_field_name)
)
raise NotSupportedError((
'Renaming the %r.%r column while in a transaction is not '
'supported on SQLite < 3.26 because it would break referential '
'integrity. Try adding `atomic = False` to the Migration class.'
) % (model._meta.db_table, old_field_name))
with atomic(self.connection.alias):
super().alter_field(model, old_field, new_field, strict=strict)
# Follow SQLite's documented procedure for performing changes
# that don't affect the on-disk content.
# https://sqlite.org/lang_altertable.html#otheralter
with self.connection.cursor() as cursor:
schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
0
]
cursor.execute("PRAGMA writable_schema = 1")
schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0]
cursor.execute('PRAGMA writable_schema = 1')
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
new_column_name = new_field.get_attname_column()[1]
search = references_template % old_column_name
replacement = references_template % new_column_name
cursor.execute(
"UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
(search, replacement),
)
cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))
cursor.execute("PRAGMA writable_schema = 0")
cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement))
cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1))
cursor.execute('PRAGMA writable_schema = 0')
# The integrity check will raise an exception and rollback
# the transaction if the sqlite_master updates corrupt the
# database.
cursor.execute("PRAGMA integrity_check")
cursor.execute('PRAGMA integrity_check')
# Perform a VACUUM to refresh the database representation from
# the sqlite_master table.
with self.connection.cursor() as cursor:
cursor.execute("VACUUM")
cursor.execute('VACUUM')
else:
super().alter_field(model, old_field, new_field, strict=strict)
def _remake_table(
self, model, create_field=None, delete_field=None, alter_field=None
):
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
"""
Shortcut to transform a model from old_model into new_model
@@ -194,7 +160,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# to an altered field.
def is_self_referential(f):
return f.is_relation and f.remote_field.model is model
# Work out the new fields dict / mapping
body = {
f.name: f.clone() if is_self_referential(f) else f
@@ -202,18 +167,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
}
# Since mapping might mix column names and default values,
# its values must be already quoted.
mapping = {
f.column: self.quote_name(f.column)
for f in model._meta.local_concrete_fields
}
mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
# This maps field names (not columns) for things like unique_together
rename_mapping = {}
# If any of the new or altered fields is introducing a new PK,
# remove the old one
restore_pk_field = None
if getattr(create_field, "primary_key", False) or (
alter_field and getattr(alter_field[1], "primary_key", False)
):
if getattr(create_field, 'primary_key', False) or (
alter_field and getattr(alter_field[1], 'primary_key', False)):
for name, field in list(body.items()):
if field.primary_key:
field.primary_key = False
@@ -237,8 +198,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
body[new_field.name] = new_field
if old_field.null and not new_field.null:
case_sql = "coalesce(%(col)s, %(default)s)" % {
"col": self.quote_name(old_field.column),
"default": self.quote_value(self.effective_default(new_field)),
'col': self.quote_name(old_field.column),
'default': self.quote_value(self.effective_default(new_field))
}
mapping[new_field.column] = case_sql
else:
@@ -249,10 +210,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
del body[delete_field.name]
del mapping[delete_field.column]
# Remove any implicit M2M tables
if (
delete_field.many_to_many
and delete_field.remote_field.through._meta.auto_created
):
if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
return self.delete_model(delete_field.remote_field.through)
# Work inside a new app registry
apps = Apps()
@@ -274,7 +232,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
indexes = model._meta.indexes
if delete_field:
indexes = [
index for index in indexes if delete_field.name not in index.fields
index for index in indexes
if delete_field.name not in index.fields
]
constraints = list(model._meta.constraints)
@@ -290,57 +249,52 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# This wouldn't be required if the schema editor was operating on model
# states instead of rendered models.
meta_contents = {
"app_label": model._meta.app_label,
"db_table": model._meta.db_table,
"unique_together": unique_together,
"index_together": index_together,
"indexes": indexes,
"constraints": constraints,
"apps": apps,
'app_label': model._meta.app_label,
'db_table': model._meta.db_table,
'unique_together': unique_together,
'index_together': index_together,
'indexes': indexes,
'constraints': constraints,
'apps': apps,
}
meta = type("Meta", (), meta_contents)
body_copy["Meta"] = meta
body_copy["__module__"] = model.__module__
body_copy['Meta'] = meta
body_copy['__module__'] = model.__module__
type(model._meta.object_name, model.__bases__, body_copy)
# Construct a model with a renamed table name.
body_copy = copy.deepcopy(body)
meta_contents = {
"app_label": model._meta.app_label,
"db_table": "new__%s" % strip_quotes(model._meta.db_table),
"unique_together": unique_together,
"index_together": index_together,
"indexes": indexes,
"constraints": constraints,
"apps": apps,
'app_label': model._meta.app_label,
'db_table': 'new__%s' % strip_quotes(model._meta.db_table),
'unique_together': unique_together,
'index_together': index_together,
'indexes': indexes,
'constraints': constraints,
'apps': apps,
}
meta = type("Meta", (), meta_contents)
body_copy["Meta"] = meta
body_copy["__module__"] = model.__module__
new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)
body_copy['Meta'] = meta
body_copy['__module__'] = model.__module__
new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy)
# Create a new table with the updated schema.
self.create_model(new_model)
# Copy data from the old table into the new table
self.execute(
"INSERT INTO %s (%s) SELECT %s FROM %s"
% (
self.quote_name(new_model._meta.db_table),
", ".join(self.quote_name(x) for x in mapping),
", ".join(mapping.values()),
self.quote_name(model._meta.db_table),
)
)
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(new_model._meta.db_table),
', '.join(self.quote_name(x) for x in mapping),
', '.join(mapping.values()),
self.quote_name(model._meta.db_table),
))
# Delete the old table to make way for the new
self.delete_model(model, handle_autom2m=False)
# Rename the new table to take way for the old
self.alter_db_table(
new_model,
new_model._meta.db_table,
model._meta.db_table,
new_model, new_model._meta.db_table, model._meta.db_table,
disable_constraints=False,
)
@@ -357,17 +311,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
super().delete_model(model)
else:
# Delete the table (and only that)
self.execute(
self.sql_delete_table
% {
"table": self.quote_name(model._meta.db_table),
}
)
self.execute(self.sql_delete_table % {
"table": self.quote_name(model._meta.db_table),
})
# Remove all deferred statements referencing the deleted table.
for sql in list(self.deferred_sql):
if isinstance(sql, Statement) and sql.references_table(
model._meta.db_table
):
if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
self.deferred_sql.remove(sql)
def add_field(self, model, field):
@@ -394,40 +343,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# For everything else, remake.
else:
# It might not actually have a column behind it
if field.db_parameters(connection=self.connection)["type"] is None:
if field.db_parameters(connection=self.connection)['type'] is None:
return
self._remake_table(model, delete_field=field)
def _alter_field(
self,
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict=False,
):
def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
"""Perform a "physical" (non-ManyToMany) field update."""
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
# changed and there aren't any constraints.
if (
self.connection.features.can_alter_table_rename_column
and old_field.column != new_field.column
and self.column_sql(model, old_field) == self.column_sql(model, new_field)
and not (
old_field.remote_field
and old_field.db_constraint
or new_field.remote_field
and new_field.db_constraint
)
):
return self.execute(
self._rename_field_sql(
model._meta.db_table, old_field, new_field, new_type
)
)
if (self.connection.features.can_alter_table_rename_column and
old_field.column != new_field.column and
self.column_sql(model, old_field) == self.column_sql(model, new_field) and
not (old_field.remote_field and old_field.db_constraint or
new_field.remote_field and new_field.db_constraint)):
return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
# Alter by remaking table
self._remake_table(model, alter_field=(old_field, new_field))
# Rebuild tables with FKs pointing to this field.
@@ -455,25 +385,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""
if (
old_field.remote_field.through._meta.db_table
== new_field.remote_field.through._meta.db_table
):
# The field name didn't change, but some options did, so we have to
# propagate this altering.
if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table:
# The field name didn't change, but some options did; we have to propagate this altering.
self._remake_table(
old_field.remote_field.through,
alter_field=(
# The field that points to the target model is needed, so
# we can tell alter_field to change it - this is
# m2m_reverse_field_name() (as opposed to m2m_field_name(),
# which points to our model).
old_field.remote_field.through._meta.get_field(
old_field.m2m_reverse_field_name()
),
new_field.remote_field.through._meta.get_field(
new_field.m2m_reverse_field_name()
),
# We need the field that points to the target model, so we can tell alter_field to change it -
# this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model)
old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
),
)
return
@@ -481,51 +401,34 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Make a new through table
self.create_model(new_field.remote_field.through)
# Copy the data across
self.execute(
"INSERT INTO %s (%s) SELECT %s FROM %s"
% (
self.quote_name(new_field.remote_field.through._meta.db_table),
", ".join(
[
"id",
new_field.m2m_column_name(),
new_field.m2m_reverse_name(),
]
),
", ".join(
[
"id",
old_field.m2m_column_name(),
old_field.m2m_reverse_name(),
]
),
self.quote_name(old_field.remote_field.through._meta.db_table),
)
)
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(new_field.remote_field.through._meta.db_table),
', '.join([
"id",
new_field.m2m_column_name(),
new_field.m2m_reverse_name(),
]),
', '.join([
"id",
old_field.m2m_column_name(),
old_field.m2m_reverse_name(),
]),
self.quote_name(old_field.remote_field.through._meta.db_table),
))
# Delete the old through table
self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition
or constraint.contains_expressions
or constraint.include
or constraint.deferrable
):
if isinstance(constraint, UniqueConstraint) and constraint.condition:
super().add_constraint(model, constraint)
else:
self._remake_table(model)
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition
or constraint.contains_expressions
or constraint.include
or constraint.deferrable
):
if isinstance(constraint, UniqueConstraint) and constraint.condition:
super().remove_constraint(model, constraint)
else:
self._remake_table(model)
def _collate_sql(self, collation):
return "COLLATE " + collation
return ' COLLATE ' + collation
@@ -7,9 +7,8 @@ import time
from contextlib import contextmanager
from django.db import NotSupportedError
from django.utils.dateparse import parse_time
logger = logging.getLogger("django.db.backends")
logger = logging.getLogger('django.db.backends')
class CursorWrapper:
@@ -17,7 +16,7 @@ class CursorWrapper:
self.cursor = cursor
self.db = db
WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])
def __getattr__(self, attr):
cursor_attr = getattr(self.cursor, attr)
@@ -50,8 +49,8 @@ class CursorWrapper:
# database driver may support them (e.g. cx_Oracle).
if kparams is not None and not self.db.features.supports_callproc_kwargs:
raise NotSupportedError(
"Keyword parameters for callproc are not supported on this "
"database backend."
'Keyword parameters for callproc are not supported on this '
'database backend.'
)
self.db.validate_no_broken_transaction()
with self.db.wrap_database_errors:
@@ -64,17 +63,13 @@ class CursorWrapper:
return self.cursor.callproc(procname, params, kparams)
def execute(self, sql, params=None):
return self._execute_with_wrappers(
sql, params, many=False, executor=self._execute
)
return self._execute_with_wrappers(sql, params, many=False, executor=self._execute)
def executemany(self, sql, param_list):
return self._execute_with_wrappers(
sql, param_list, many=True, executor=self._executemany
)
return self._execute_with_wrappers(sql, param_list, many=True, executor=self._executemany)
def _execute_with_wrappers(self, sql, params, many, executor):
context = {"connection": self.db, "cursor": self}
context = {'connection': self.db, 'cursor': self}
for wrapper in reversed(self.db.execute_wrappers):
executor = functools.partial(wrapper, executor)
return executor(sql, params, many, context)
@@ -107,9 +102,7 @@ class CursorDebugWrapper(CursorWrapper):
return super().executemany(sql, param_list)
@contextmanager
def debug_sql(
self, sql=None, params=None, use_last_executed_query=False, many=False
):
def debug_sql(self, sql=None, params=None, use_last_executed_query=False, many=False):
start = time.monotonic()
try:
yield
@@ -119,65 +112,40 @@ class CursorDebugWrapper(CursorWrapper):
if use_last_executed_query:
sql = self.db.ops.last_executed_query(self.cursor, sql, params)
try:
times = len(params) if many else ""
times = len(params) if many else ''
except TypeError:
# params could be an iterator.
times = "?"
self.db.queries_log.append(
{
"sql": "%s times: %s" % (times, sql) if many else sql,
"time": "%.3f" % duration,
}
)
times = '?'
self.db.queries_log.append({
'sql': '%s times: %s' % (times, sql) if many else sql,
'time': '%.3f' % duration,
})
logger.debug(
"(%.3f) %s; args=%s; alias=%s",
'(%.3f) %s; args=%s',
duration,
sql,
params,
self.db.alias,
extra={
"duration": duration,
"sql": sql,
"params": params,
"alias": self.db.alias,
},
extra={'duration': duration, 'sql': sql, 'params': params},
)
def split_tzname_delta(tzname):
"""
Split a time zone name into a 3-tuple of (name, sign, offset).
"""
for sign in ["+", "-"]:
if sign in tzname:
name, offset = tzname.rsplit(sign, 1)
if offset and parse_time(offset):
return name, sign, offset
return tzname, None, None
###############################################
# Converters from database (string) to Python #
###############################################
def typecast_date(s):
return (
datetime.date(*map(int, s.split("-"))) if s else None
) # return None if s is null
return datetime.date(*map(int, s.split('-'))) if s else None # return None if s is null
def typecast_time(s): # does NOT store time zone information
if not s:
return None
hour, minutes, seconds = s.split(":")
if "." in seconds: # check whether seconds have a fractional part
seconds, microseconds = seconds.split(".")
hour, minutes, seconds = s.split(':')
if '.' in seconds: # check whether seconds have a fractional part
seconds, microseconds = seconds.split('.')
else:
microseconds = "0"
return datetime.time(
int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6])
)
microseconds = '0'
return datetime.time(int(hour), int(minutes), int(seconds), int((microseconds + '000000')[:6]))
def typecast_timestamp(s): # does NOT store time zone information
@@ -185,29 +153,25 @@ def typecast_timestamp(s): # does NOT store time zone information
# "2005-07-29 09:56:00-05"
if not s:
return None
if " " not in s:
if ' ' not in s:
return typecast_date(s)
d, t = s.split()
# Remove timezone information.
if "-" in t:
t, _ = t.split("-", 1)
elif "+" in t:
t, _ = t.split("+", 1)
dates = d.split("-")
times = t.split(":")
if '-' in t:
t, _ = t.split('-', 1)
elif '+' in t:
t, _ = t.split('+', 1)
dates = d.split('-')
times = t.split(':')
seconds = times[2]
if "." in seconds: # check whether seconds have a fractional part
seconds, microseconds = seconds.split(".")
if '.' in seconds: # check whether seconds have a fractional part
seconds, microseconds = seconds.split('.')
else:
microseconds = "0"
microseconds = '0'
return datetime.datetime(
int(dates[0]),
int(dates[1]),
int(dates[2]),
int(times[0]),
int(times[1]),
int(seconds),
int((microseconds + "000000")[:6]),
int(dates[0]), int(dates[1]), int(dates[2]),
int(times[0]), int(times[1]), int(seconds),
int((microseconds + '000000')[:6])
)
@@ -215,7 +179,6 @@ def typecast_timestamp(s): # does NOT store time zone information
# Converters from Python to database (string) #
###############################################
def split_identifier(identifier):
"""
Split an SQL identifier into a two element tuple of (namespace, name).
@@ -226,7 +189,7 @@ def split_identifier(identifier):
try:
namespace, name = identifier.split('"."')
except ValueError:
namespace, name = "", identifier
namespace, name = '', identifier
return namespace.strip('"'), name.strip('"')
@@ -244,11 +207,7 @@ def truncate_name(identifier, length=None, hash_len=4):
return identifier
digest = names_digest(name, length=hash_len)
return "%s%s%s" % (
'%s"."' % namespace if namespace else "",
name[: length - hash_len],
digest,
)
return '%s%s%s' % ('%s"."' % namespace if namespace else '', name[:length - hash_len], digest)
def names_digest(*args, length):
@@ -273,9 +232,7 @@ def format_number(value, max_digits, decimal_places):
if max_digits is not None:
context.prec = max_digits
if decimal_places is not None:
value = value.quantize(
decimal.Decimal(1).scaleb(-decimal_places), context=context
)
value = value.quantize(decimal.Decimal(1).scaleb(-decimal_places), context=context)
else:
context.traps[decimal.Rounded] = 1
value = context.create_decimal(value)