测试gitnore

This commit is contained in:
ladeng07
2022-05-06 15:45:57 +08:00
parent 12f390949b
commit 51552904f9
2347 changed files with 120102 additions and 53549 deletions
@@ -14,15 +14,18 @@ import warnings
from itertools import chain
from sqlite3 import dbapi2 as Database
import pytz
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper, timezone_constructor
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils import timezone
from django.utils.asyncio import async_unsafe
from django.utils.dateparse import parse_datetime, parse_time
from django.utils.duration import duration_microseconds
from django.utils.regex_helper import _lazy_re_compile
from django.utils.version import PY38
from .client import DatabaseClient
from .creation import DatabaseCreation
@@ -46,11 +49,9 @@ def none_guard(func):
are NULL. This decorator simplifies the implementation of this for the
custom functions registered below.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
return None if None in args else func(*args, **kwargs)
return wrapper
@@ -59,19 +60,19 @@ def list_aggregate(function):
Return an aggregate class that accumulates values in a list and applies
the provided function to the data.
"""
return type("ListAggregate", (list,), {"finalize": function, "step": list.append})
return type('ListAggregate', (list,), {'finalize': function, 'step': list.append})
def check_sqlite_version():
if Database.sqlite_version_info < (3, 9, 0):
raise ImproperlyConfigured(
"SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version
'SQLite 3.9.0 or later is required (found %s).' % Database.sqlite_version
)
check_sqlite_version()
Database.register_converter("bool", b"1".__eq__)
Database.register_converter("bool", b'1'.__eq__)
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))
@@ -80,69 +81,70 @@ Database.register_adapter(decimal.Decimal, str)
class DatabaseWrapper(BaseDatabaseWrapper):
vendor = "sqlite"
display_name = "SQLite"
vendor = 'sqlite'
display_name = 'SQLite'
# SQLite doesn't actually support most of these types, but it "does the right
# thing" given more verbose field definitions, so leave them as is so that
# schema inspection is more useful.
data_types = {
"AutoField": "integer",
"BigAutoField": "integer",
"BinaryField": "BLOB",
"BooleanField": "bool",
"CharField": "varchar(%(max_length)s)",
"DateField": "date",
"DateTimeField": "datetime",
"DecimalField": "decimal",
"DurationField": "bigint",
"FileField": "varchar(%(max_length)s)",
"FilePathField": "varchar(%(max_length)s)",
"FloatField": "real",
"IntegerField": "integer",
"BigIntegerField": "bigint",
"IPAddressField": "char(15)",
"GenericIPAddressField": "char(39)",
"JSONField": "text",
"OneToOneField": "integer",
"PositiveBigIntegerField": "bigint unsigned",
"PositiveIntegerField": "integer unsigned",
"PositiveSmallIntegerField": "smallint unsigned",
"SlugField": "varchar(%(max_length)s)",
"SmallAutoField": "integer",
"SmallIntegerField": "smallint",
"TextField": "text",
"TimeField": "time",
"UUIDField": "char(32)",
'AutoField': 'integer',
'BigAutoField': 'integer',
'BinaryField': 'BLOB',
'BooleanField': 'bool',
'CharField': 'varchar(%(max_length)s)',
'DateField': 'date',
'DateTimeField': 'datetime',
'DecimalField': 'decimal',
'DurationField': 'bigint',
'FileField': 'varchar(%(max_length)s)',
'FilePathField': 'varchar(%(max_length)s)',
'FloatField': 'real',
'IntegerField': 'integer',
'BigIntegerField': 'bigint',
'IPAddressField': 'char(15)',
'GenericIPAddressField': 'char(39)',
'JSONField': 'text',
'NullBooleanField': 'bool',
'OneToOneField': 'integer',
'PositiveBigIntegerField': 'bigint unsigned',
'PositiveIntegerField': 'integer unsigned',
'PositiveSmallIntegerField': 'smallint unsigned',
'SlugField': 'varchar(%(max_length)s)',
'SmallAutoField': 'integer',
'SmallIntegerField': 'smallint',
'TextField': 'text',
'TimeField': 'time',
'UUIDField': 'char(32)',
}
data_type_check_constraints = {
"PositiveBigIntegerField": '"%(column)s" >= 0',
"JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
"PositiveIntegerField": '"%(column)s" >= 0',
"PositiveSmallIntegerField": '"%(column)s" >= 0',
'PositiveBigIntegerField': '"%(column)s" >= 0',
'JSONField': '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
'PositiveIntegerField': '"%(column)s" >= 0',
'PositiveSmallIntegerField': '"%(column)s" >= 0',
}
data_types_suffix = {
"AutoField": "AUTOINCREMENT",
"BigAutoField": "AUTOINCREMENT",
"SmallAutoField": "AUTOINCREMENT",
'AutoField': 'AUTOINCREMENT',
'BigAutoField': 'AUTOINCREMENT',
'SmallAutoField': 'AUTOINCREMENT',
}
# SQLite requires LIKE statements to include an ESCAPE clause if the value
# being escaped has a percent or underscore in it.
# See https://www.sqlite.org/lang_expr.html for an explanation.
operators = {
"exact": "= %s",
"iexact": "LIKE %s ESCAPE '\\'",
"contains": "LIKE %s ESCAPE '\\'",
"icontains": "LIKE %s ESCAPE '\\'",
"regex": "REGEXP %s",
"iregex": "REGEXP '(?i)' || %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"startswith": "LIKE %s ESCAPE '\\'",
"endswith": "LIKE %s ESCAPE '\\'",
"istartswith": "LIKE %s ESCAPE '\\'",
"iendswith": "LIKE %s ESCAPE '\\'",
'exact': '= %s',
'iexact': "LIKE %s ESCAPE '\\'",
'contains': "LIKE %s ESCAPE '\\'",
'icontains': "LIKE %s ESCAPE '\\'",
'regex': 'REGEXP %s',
'iregex': "REGEXP '(?i)' || %s",
'gt': '> %s',
'gte': '>= %s',
'lt': '< %s',
'lte': '<= %s',
'startswith': "LIKE %s ESCAPE '\\'",
'endswith': "LIKE %s ESCAPE '\\'",
'istartswith': "LIKE %s ESCAPE '\\'",
'iendswith': "LIKE %s ESCAPE '\\'",
}
# The patterns below are used to generate SQL pattern lookup clauses when
@@ -155,12 +157,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# the LIKE operator.
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
pattern_ops = {
"contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
"icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
"startswith": r"LIKE {} || '%%' ESCAPE '\'",
"istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
"endswith": r"LIKE '%%' || {} ESCAPE '\'",
"iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
'contains': r"LIKE '%%' || {} || '%%' ESCAPE '\'",
'icontains': r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
'startswith': r"LIKE {} || '%%' ESCAPE '\'",
'istartswith': r"LIKE UPPER({}) || '%%' ESCAPE '\'",
'endswith': r"LIKE '%%' || {} ESCAPE '\'",
'iendswith': r"LIKE '%%' || UPPER({}) ESCAPE '\'",
}
Database = Database
@@ -174,15 +176,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def get_connection_params(self):
settings_dict = self.settings_dict
if not settings_dict["NAME"]:
if not settings_dict['NAME']:
raise ImproperlyConfigured(
"settings.DATABASES is improperly configured. "
"Please supply the NAME value."
)
"Please supply the NAME value.")
kwargs = {
"database": settings_dict["NAME"],
"detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
**settings_dict["OPTIONS"],
# TODO: Remove str() when dropping support for PY36.
# https://bugs.python.org/issue33496
'database': str(settings_dict['NAME']),
'detect_types': Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
**settings_dict['OPTIONS'],
}
# Always allow the underlying SQLite connection to be shareable
# between multiple threads. The safe-guarding will be handled at a
@@ -190,103 +193,78 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# property. This is necessary as the shareability is disabled by
# default in pysqlite and it cannot be changed once a connection is
# opened.
if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
if 'check_same_thread' in kwargs and kwargs['check_same_thread']:
warnings.warn(
"The `check_same_thread` option was provided and set to "
"True. It will be overridden with False. Use the "
"`DatabaseWrapper.allow_thread_sharing` property instead "
"for controlling thread shareability.",
RuntimeWarning,
'The `check_same_thread` option was provided and set to '
'True. It will be overridden with False. Use the '
'`DatabaseWrapper.allow_thread_sharing` property instead '
'for controlling thread shareability.',
RuntimeWarning
)
kwargs.update({"check_same_thread": False, "uri": True})
kwargs.update({'check_same_thread': False, 'uri': True})
return kwargs
@async_unsafe
def get_new_connection(self, conn_params):
conn = Database.connect(**conn_params)
create_deterministic_function = functools.partial(
conn.create_function,
deterministic=True,
)
create_deterministic_function(
"django_date_extract", 2, _sqlite_datetime_extract
)
create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
create_deterministic_function(
"django_datetime_cast_date", 3, _sqlite_datetime_cast_date
)
create_deterministic_function(
"django_datetime_cast_time", 3, _sqlite_datetime_cast_time
)
create_deterministic_function(
"django_datetime_extract", 4, _sqlite_datetime_extract
)
create_deterministic_function(
"django_datetime_trunc", 4, _sqlite_datetime_trunc
)
create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
create_deterministic_function(
"django_timestamp_diff", 2, _sqlite_timestamp_diff
)
create_deterministic_function(
"django_format_dtdelta", 3, _sqlite_format_dtdelta
)
create_deterministic_function("regexp", 2, _sqlite_regexp)
create_deterministic_function("ACOS", 1, none_guard(math.acos))
create_deterministic_function("ASIN", 1, none_guard(math.asin))
create_deterministic_function("ATAN", 1, none_guard(math.atan))
create_deterministic_function("ATAN2", 2, none_guard(math.atan2))
create_deterministic_function("BITXOR", 2, none_guard(operator.xor))
create_deterministic_function("CEILING", 1, none_guard(math.ceil))
create_deterministic_function("COS", 1, none_guard(math.cos))
create_deterministic_function("COT", 1, none_guard(lambda x: 1 / math.tan(x)))
create_deterministic_function("DEGREES", 1, none_guard(math.degrees))
create_deterministic_function("EXP", 1, none_guard(math.exp))
create_deterministic_function("FLOOR", 1, none_guard(math.floor))
create_deterministic_function("LN", 1, none_guard(math.log))
create_deterministic_function("LOG", 2, none_guard(lambda x, y: math.log(y, x)))
create_deterministic_function("LPAD", 3, _sqlite_lpad)
create_deterministic_function(
"MD5", 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest())
)
create_deterministic_function("MOD", 2, none_guard(math.fmod))
create_deterministic_function("PI", 0, lambda: math.pi)
create_deterministic_function("POWER", 2, none_guard(operator.pow))
create_deterministic_function("RADIANS", 1, none_guard(math.radians))
create_deterministic_function("REPEAT", 2, none_guard(operator.mul))
create_deterministic_function("REVERSE", 1, none_guard(lambda x: x[::-1]))
create_deterministic_function("RPAD", 3, _sqlite_rpad)
create_deterministic_function(
"SHA1", 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA224", 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA256", 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA384", 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest())
)
create_deterministic_function(
"SHA512", 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest())
)
create_deterministic_function(
"SIGN", 1, none_guard(lambda x: (x > 0) - (x < 0))
)
create_deterministic_function("SIN", 1, none_guard(math.sin))
create_deterministic_function("SQRT", 1, none_guard(math.sqrt))
create_deterministic_function("TAN", 1, none_guard(math.tan))
if PY38:
create_deterministic_function = functools.partial(
conn.create_function,
deterministic=True,
)
else:
create_deterministic_function = conn.create_function
create_deterministic_function('django_date_extract', 2, _sqlite_datetime_extract)
create_deterministic_function('django_date_trunc', 4, _sqlite_date_trunc)
create_deterministic_function('django_datetime_cast_date', 3, _sqlite_datetime_cast_date)
create_deterministic_function('django_datetime_cast_time', 3, _sqlite_datetime_cast_time)
create_deterministic_function('django_datetime_extract', 4, _sqlite_datetime_extract)
create_deterministic_function('django_datetime_trunc', 4, _sqlite_datetime_trunc)
create_deterministic_function('django_time_extract', 2, _sqlite_time_extract)
create_deterministic_function('django_time_trunc', 4, _sqlite_time_trunc)
create_deterministic_function('django_time_diff', 2, _sqlite_time_diff)
create_deterministic_function('django_timestamp_diff', 2, _sqlite_timestamp_diff)
create_deterministic_function('django_format_dtdelta', 3, _sqlite_format_dtdelta)
create_deterministic_function('regexp', 2, _sqlite_regexp)
create_deterministic_function('ACOS', 1, none_guard(math.acos))
create_deterministic_function('ASIN', 1, none_guard(math.asin))
create_deterministic_function('ATAN', 1, none_guard(math.atan))
create_deterministic_function('ATAN2', 2, none_guard(math.atan2))
create_deterministic_function('BITXOR', 2, none_guard(operator.xor))
create_deterministic_function('CEILING', 1, none_guard(math.ceil))
create_deterministic_function('COS', 1, none_guard(math.cos))
create_deterministic_function('COT', 1, none_guard(lambda x: 1 / math.tan(x)))
create_deterministic_function('DEGREES', 1, none_guard(math.degrees))
create_deterministic_function('EXP', 1, none_guard(math.exp))
create_deterministic_function('FLOOR', 1, none_guard(math.floor))
create_deterministic_function('LN', 1, none_guard(math.log))
create_deterministic_function('LOG', 2, none_guard(lambda x, y: math.log(y, x)))
create_deterministic_function('LPAD', 3, _sqlite_lpad)
create_deterministic_function('MD5', 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest()))
create_deterministic_function('MOD', 2, none_guard(math.fmod))
create_deterministic_function('PI', 0, lambda: math.pi)
create_deterministic_function('POWER', 2, none_guard(operator.pow))
create_deterministic_function('RADIANS', 1, none_guard(math.radians))
create_deterministic_function('REPEAT', 2, none_guard(operator.mul))
create_deterministic_function('REVERSE', 1, none_guard(lambda x: x[::-1]))
create_deterministic_function('RPAD', 3, _sqlite_rpad)
create_deterministic_function('SHA1', 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest()))
create_deterministic_function('SHA224', 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest()))
create_deterministic_function('SHA256', 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest()))
create_deterministic_function('SHA384', 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest()))
create_deterministic_function('SHA512', 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest()))
create_deterministic_function('SIGN', 1, none_guard(lambda x: (x > 0) - (x < 0)))
create_deterministic_function('SIN', 1, none_guard(math.sin))
create_deterministic_function('SQRT', 1, none_guard(math.sqrt))
create_deterministic_function('TAN', 1, none_guard(math.tan))
# Don't use the built-in RANDOM() function because it returns a value
# in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
conn.create_function("RAND", 0, random.random)
conn.create_aggregate("STDDEV_POP", 1, list_aggregate(statistics.pstdev))
conn.create_aggregate("STDDEV_SAMP", 1, list_aggregate(statistics.stdev))
conn.create_aggregate("VAR_POP", 1, list_aggregate(statistics.pvariance))
conn.create_aggregate("VAR_SAMP", 1, list_aggregate(statistics.variance))
conn.execute("PRAGMA foreign_keys = ON")
# in the range [2^63, 2^63 - 1] instead of [0, 1).
conn.create_function('RAND', 0, random.random)
conn.create_aggregate('STDDEV_POP', 1, list_aggregate(statistics.pstdev))
conn.create_aggregate('STDDEV_SAMP', 1, list_aggregate(statistics.stdev))
conn.create_aggregate('VAR_POP', 1, list_aggregate(statistics.pvariance))
conn.create_aggregate('VAR_SAMP', 1, list_aggregate(statistics.variance))
conn.execute('PRAGMA foreign_keys = ON')
return conn
def init_connection_state(self):
@@ -318,7 +296,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
else:
# sqlite3's internal default is ''. It's different from None.
# See Modules/_sqlite/connection.c.
level = ""
level = ''
# 'isolation_level' is a misleading API.
# SQLite always runs at the SERIALIZABLE isolation level.
with self.wrap_database_errors:
@@ -326,16 +304,16 @@ class DatabaseWrapper(BaseDatabaseWrapper):
def disable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute("PRAGMA foreign_keys = OFF")
cursor.execute('PRAGMA foreign_keys = OFF')
# Foreign key constraints cannot be turned off while in a multi-
# statement transaction. Fetch the current state of the pragma
# to determine if constraints are effectively disabled.
enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
enabled = cursor.execute('PRAGMA foreign_keys').fetchone()[0]
return not bool(enabled)
def enable_constraint_checking(self):
with self.cursor() as cursor:
cursor.execute("PRAGMA foreign_keys = ON")
cursor.execute('PRAGMA foreign_keys = ON')
def check_constraints(self, table_names=None):
"""
@@ -348,32 +326,24 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if self.features.supports_pragma_foreign_key_check:
with self.cursor() as cursor:
if table_names is None:
violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
violations = cursor.execute('PRAGMA foreign_key_check').fetchall()
else:
violations = chain.from_iterable(
cursor.execute(
"PRAGMA foreign_key_check(%s)"
'PRAGMA foreign_key_check(%s)'
% self.ops.quote_name(table_name)
).fetchall()
for table_name in table_names
)
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
for (
table_name,
rowid,
referenced_table_name,
foreign_key_index,
) in violations:
for table_name, rowid, referenced_table_name, foreign_key_index in violations:
foreign_key = cursor.execute(
"PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
'PRAGMA foreign_key_list(%s)' % self.ops.quote_name(table_name)
).fetchall()[foreign_key_index]
column_name, referenced_column_name = foreign_key[3:5]
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
primary_key_value, bad_value = cursor.execute(
"SELECT %s, %s FROM %s WHERE rowid = %%s"
% (
'SELECT %s, %s FROM %s WHERE rowid = %%s' % (
self.ops.quote_name(primary_key_column_name),
self.ops.quote_name(column_name),
self.ops.quote_name(table_name),
@@ -383,15 +353,9 @@ class DatabaseWrapper(BaseDatabaseWrapper):
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
primary_key_value,
table_name,
column_name,
bad_value,
referenced_table_name,
referenced_column_name,
"does not have a corresponding value in %s.%s." % (
table_name, primary_key_value, table_name, column_name,
bad_value, referenced_table_name, referenced_column_name
)
)
else:
@@ -399,17 +363,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
if table_names is None:
table_names = self.introspection.table_names(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(
cursor, table_name
)
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for (
column_name,
referenced_table_name,
referenced_column_name,
) in key_columns:
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute(
"""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
@@ -418,29 +376,18 @@ class DatabaseWrapper(BaseDatabaseWrapper):
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
"""
% (
primary_key_column_name,
column_name,
table_name,
referenced_table_name,
column_name,
referenced_column_name,
column_name,
referenced_column_name,
primary_key_column_name, column_name, table_name,
referenced_table_name, column_name, referenced_column_name,
column_name, referenced_column_name,
)
)
for bad_row in cursor.fetchall():
raise IntegrityError(
"The row in table '%s' with primary key '%s' has an "
"invalid foreign key: %s.%s contains a value '%s' that "
"does not have a corresponding value in %s.%s."
% (
table_name,
bad_row[0],
table_name,
column_name,
bad_row[1],
referenced_table_name,
referenced_column_name,
"does not have a corresponding value in %s.%s." % (
table_name, bad_row[0], table_name, column_name,
bad_row[1], referenced_table_name, referenced_column_name,
)
)
@@ -457,10 +404,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.cursor().execute("BEGIN")
def is_in_memory_db(self):
return self.creation.is_in_memory_db(self.settings_dict["NAME"])
return self.creation.is_in_memory_db(self.settings_dict['NAME'])
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
FORMAT_QMARK_REGEX = _lazy_re_compile(r'(?<!%)%s')
class SQLiteCursorWrapper(Database.Cursor):
@@ -469,7 +416,6 @@ class SQLiteCursorWrapper(Database.Cursor):
This fixes it -- but note that if you want to use a literal "%s" in a query,
you'll need to use "%%s".
"""
def execute(self, query, params=None):
if params is None:
return Database.Cursor.execute(self, query)
@@ -481,7 +427,7 @@ class SQLiteCursorWrapper(Database.Cursor):
return Database.Cursor.executemany(self, query, param_list)
def convert_query(self, query):
return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
return FORMAT_QMARK_REGEX.sub('?', query).replace('%%', '%')
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
@@ -492,14 +438,17 @@ def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
except (TypeError, ValueError):
return None
if conn_tzname:
dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
dt = dt.replace(tzinfo=pytz.timezone(conn_tzname))
if tzname is not None and tzname != conn_tzname:
tzname, sign, offset = backend_utils.split_tzname_delta(tzname)
if offset:
hours, minutes = offset.split(":")
offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
dt += offset_delta if sign == "+" else -offset_delta
dt = timezone.localtime(dt, timezone_constructor(tzname))
sign_index = tzname.find('+') + tzname.find('-') + 1
if sign_index > -1:
sign = tzname[sign_index]
tzname, offset = tzname.split(sign)
if offset:
hours, minutes = offset.split(':')
offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
dt += offset_delta if sign == '+' else -offset_delta
dt = timezone.localtime(dt, pytz.timezone(tzname))
return dt
@@ -507,17 +456,17 @@ def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "year":
if lookup_type == 'year':
return "%i-01-01" % dt.year
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
month_in_quarter = dt.month - (dt.month - 1) % 3
return "%i-%02i-01" % (dt.year, month_in_quarter)
elif lookup_type == "month":
return '%i-%02i-01' % (dt.year, month_in_quarter)
elif lookup_type == 'month':
return "%i-%02i-01" % (dt.year, dt.month)
elif lookup_type == "week":
elif lookup_type == 'week':
dt = dt - datetime.timedelta(days=dt.weekday())
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
elif lookup_type == "day":
elif lookup_type == 'day':
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
@@ -532,11 +481,11 @@ def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
return None
else:
dt = dt_parsed
if lookup_type == "hour":
if lookup_type == 'hour':
return "%02i:00:00" % dt.hour
elif lookup_type == "minute":
elif lookup_type == 'minute':
return "%02i:%02i:00" % (dt.hour, dt.minute)
elif lookup_type == "second":
elif lookup_type == 'second':
return "%02i:%02i:%02i" % (dt.hour, dt.minute, dt.second)
@@ -558,15 +507,15 @@ def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "week_day":
if lookup_type == 'week_day':
return (dt.isoweekday() % 7) + 1
elif lookup_type == "iso_week_day":
elif lookup_type == 'iso_week_day':
return dt.isoweekday()
elif lookup_type == "week":
elif lookup_type == 'week':
return dt.isocalendar()[1]
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
return math.ceil(dt.month / 3)
elif lookup_type == "iso_year":
elif lookup_type == 'iso_year':
return dt.isocalendar()[0]
else:
return getattr(dt, lookup_type)
@@ -576,37 +525,24 @@ def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
if dt is None:
return None
if lookup_type == "year":
if lookup_type == 'year':
return "%i-01-01 00:00:00" % dt.year
elif lookup_type == "quarter":
elif lookup_type == 'quarter':
month_in_quarter = dt.month - (dt.month - 1) % 3
return "%i-%02i-01 00:00:00" % (dt.year, month_in_quarter)
elif lookup_type == "month":
return '%i-%02i-01 00:00:00' % (dt.year, month_in_quarter)
elif lookup_type == 'month':
return "%i-%02i-01 00:00:00" % (dt.year, dt.month)
elif lookup_type == "week":
elif lookup_type == 'week':
dt = dt - datetime.timedelta(days=dt.weekday())
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
elif lookup_type == "day":
elif lookup_type == 'day':
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
elif lookup_type == "hour":
elif lookup_type == 'hour':
return "%i-%02i-%02i %02i:00:00" % (dt.year, dt.month, dt.day, dt.hour)
elif lookup_type == "minute":
return "%i-%02i-%02i %02i:%02i:00" % (
dt.year,
dt.month,
dt.day,
dt.hour,
dt.minute,
)
elif lookup_type == "second":
return "%i-%02i-%02i %02i:%02i:%02i" % (
dt.year,
dt.month,
dt.day,
dt.hour,
dt.minute,
dt.second,
)
elif lookup_type == 'minute':
return "%i-%02i-%02i %02i:%02i:00" % (dt.year, dt.month, dt.day, dt.hour, dt.minute)
elif lookup_type == 'second':
return "%i-%02i-%02i %02i:%02i:%02i" % (dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)
def _sqlite_time_extract(lookup_type, dt):
@@ -619,40 +555,25 @@ def _sqlite_time_extract(lookup_type, dt):
return getattr(dt, lookup_type)
def _sqlite_prepare_dtdelta_param(conn, param):
if conn in ["+", "-"]:
if isinstance(param, int):
return datetime.timedelta(0, 0, param)
else:
return backend_utils.typecast_timestamp(param)
return param
@none_guard
def _sqlite_format_dtdelta(conn, lhs, rhs):
"""
LHS and RHS can be either:
- An integer number of microseconds
- A string representing a datetime
- A scalar value, e.g. float
"""
conn = conn.strip()
try:
real_lhs = _sqlite_prepare_dtdelta_param(conn, lhs)
real_rhs = _sqlite_prepare_dtdelta_param(conn, rhs)
real_lhs = datetime.timedelta(0, 0, lhs) if isinstance(lhs, int) else backend_utils.typecast_timestamp(lhs)
real_rhs = datetime.timedelta(0, 0, rhs) if isinstance(rhs, int) else backend_utils.typecast_timestamp(rhs)
if conn.strip() == '+':
out = real_lhs + real_rhs
else:
out = real_lhs - real_rhs
except (ValueError, TypeError):
return None
if conn == "+":
# typecast_timestamp returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
out = str(real_lhs + real_rhs)
elif conn == "-":
out = str(real_lhs - real_rhs)
elif conn == "*":
out = real_lhs * real_rhs
else:
out = real_lhs / real_rhs
return out
# typecast_timestamp returns a date or a datetime without timezone.
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
return str(out)
@none_guard
@@ -660,14 +581,14 @@ def _sqlite_time_diff(lhs, rhs):
left = backend_utils.typecast_time(lhs)
right = backend_utils.typecast_time(rhs)
return (
(left.hour * 60 * 60 * 1000000)
+ (left.minute * 60 * 1000000)
+ (left.second * 1000000)
+ (left.microsecond)
- (right.hour * 60 * 60 * 1000000)
- (right.minute * 60 * 1000000)
- (right.second * 1000000)
- (right.microsecond)
(left.hour * 60 * 60 * 1000000) +
(left.minute * 60 * 1000000) +
(left.second * 1000000) +
(left.microsecond) -
(right.hour * 60 * 60 * 1000000) -
(right.minute * 60 * 1000000) -
(right.second * 1000000) -
(right.microsecond)
)
@@ -687,7 +608,7 @@ def _sqlite_regexp(re_pattern, re_string):
def _sqlite_lpad(text, length, fill_text):
if len(text) >= length:
return text[:length]
return (fill_text * length)[: length - len(text)] + text
return (fill_text * length)[:length - len(text)] + text
@none_guard
@@ -2,9 +2,15 @@ from django.db.backends.base.client import BaseDatabaseClient
class DatabaseClient(BaseDatabaseClient):
executable_name = "sqlite3"
executable_name = 'sqlite3'
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, settings_dict["NAME"], *parameters]
args = [
cls.executable_name,
# TODO: Remove str() when dropping support for PY37. args
# parameter accepts path-like objects on Windows since Python 3.8.
str(settings_dict['NAME']),
*parameters,
]
return args, None
@@ -7,16 +7,17 @@ from django.db.backends.base.creation import BaseDatabaseCreation
class DatabaseCreation(BaseDatabaseCreation):
@staticmethod
def is_in_memory_db(database_name):
return not isinstance(database_name, Path) and (
database_name == ":memory:" or "mode=memory" in database_name
database_name == ':memory:' or 'mode=memory' in database_name
)
def _get_test_db_name(self):
test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
if test_database_name == ":memory:":
return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
test_database_name = self.connection.settings_dict['TEST']['NAME'] or ':memory:'
if test_database_name == ':memory:':
return 'file:memorydb_%s?mode=memory&cache=shared' % self.connection.alias
return test_database_name
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
@@ -27,39 +28,38 @@ class DatabaseCreation(BaseDatabaseCreation):
if not self.is_in_memory_db(test_database_name):
# Erase the old test database
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (self._get_database_display_str(verbosity, test_database_name),)
)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, test_database_name),
))
if os.access(test_database_name, os.F_OK):
if not autoclobber:
confirm = input(
"Type 'yes' if you would like to try deleting the test "
"database '%s', or 'no' to cancel: " % test_database_name
)
if autoclobber or confirm == "yes":
if autoclobber or confirm == 'yes':
try:
os.remove(test_database_name)
except Exception as e:
self.log("Got an error deleting the old test database: %s" % e)
self.log('Got an error deleting the old test database: %s' % e)
sys.exit(2)
else:
self.log("Tests cancelled.")
self.log('Tests cancelled.')
sys.exit(1)
return test_database_name
def get_test_db_clone_settings(self, suffix):
orig_settings_dict = self.connection.settings_dict
source_database_name = orig_settings_dict["NAME"]
source_database_name = orig_settings_dict['NAME']
if self.is_in_memory_db(source_database_name):
return orig_settings_dict
else:
root, ext = os.path.splitext(orig_settings_dict["NAME"])
return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)}
root, ext = os.path.splitext(orig_settings_dict['NAME'])
return {**orig_settings_dict, 'NAME': '{}_{}.{}'.format(root, suffix, ext)}
def _clone_test_db(self, suffix, verbosity, keepdb=False):
source_database_name = self.connection.settings_dict["NAME"]
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
source_database_name = self.connection.settings_dict['NAME']
target_database_name = self.get_test_db_clone_settings(suffix)['NAME']
# Forking automatically makes a copy of an in-memory database.
if not self.is_in_memory_db(source_database_name):
# Erase the old test database
@@ -67,23 +67,18 @@ class DatabaseCreation(BaseDatabaseCreation):
if keepdb:
return
if verbosity >= 1:
self.log(
"Destroying old test database for alias %s..."
% (
self._get_database_display_str(
verbosity, target_database_name
),
)
)
self.log('Destroying old test database for alias %s...' % (
self._get_database_display_str(verbosity, target_database_name),
))
try:
os.remove(target_database_name)
except Exception as e:
self.log("Got an error deleting the old test database: %s" % e)
self.log('Got an error deleting the old test database: %s' % e)
sys.exit(2)
try:
shutil.copy(source_database_name, target_database_name)
except Exception as e:
self.log("Got an error cloning the test database: %s" % e)
self.log('Got an error cloning the test database: %s' % e)
sys.exit(2)
def _destroy_test_db(self, test_database_name, verbosity):
@@ -100,7 +95,7 @@ class DatabaseCreation(BaseDatabaseCreation):
TEST NAME. See https://www.sqlite.org/inmemorydb.html
"""
test_database_name = self._get_test_db_name()
sig = [self.connection.settings_dict["NAME"]]
sig = [self.connection.settings_dict['NAME']]
if self.is_in_memory_db(test_database_name):
sig.append(self.connection.alias)
else:
@@ -45,82 +45,58 @@ class DatabaseFeatures(BaseDatabaseFeatures):
order_by_nulls_first = True
supports_json_field_contains = False
test_collations = {
"ci": "nocase",
"cs": "binary",
"non_default": "nocase",
'ci': 'nocase',
'cs': 'binary',
'non_default': 'nocase',
}
@cached_property
def django_test_skips(self):
skips = {
"SQLite stores values rounded to 15 significant digits.": {
"model_fields.test_decimalfield.DecimalFieldTests."
"test_fetch_from_db_without_float_rounding",
'SQLite stores values rounded to 15 significant digits.': {
'model_fields.test_decimalfield.DecimalFieldTests.test_fetch_from_db_without_float_rounding',
},
"SQLite naively remakes the table on field alteration.": {
"schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops",
"schema.tests.SchemaTests.test_unique_and_reverse_m2m",
"schema.tests.SchemaTests."
"test_alter_field_default_doesnt_perform_queries",
"schema.tests.SchemaTests."
"test_rename_column_renames_deferred_sql_references",
'SQLite naively remakes the table on field alteration.': {
'schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops',
'schema.tests.SchemaTests.test_unique_and_reverse_m2m',
'schema.tests.SchemaTests.test_alter_field_default_doesnt_perform_queries',
'schema.tests.SchemaTests.test_rename_column_renames_deferred_sql_references',
},
"SQLite doesn't have a constraint.": {
"model_fields.test_integerfield.PositiveIntegerFieldTests."
"test_negative_values",
},
"SQLite doesn't support negative precision for ROUND().": {
"db_functions.math.test_round.RoundTests."
"test_null_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_decimal_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_float_with_negative_precision",
"db_functions.math.test_round.RoundTests."
"test_integer_with_negative_precision",
'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
},
}
if Database.sqlite_version_info < (3, 27):
skips.update(
{
"Nondeterministic failure on SQLite < 3.27.": {
"expressions_window.tests.WindowFunctionTests."
"test_subquery_row_range_rank",
},
}
)
skips.update({
'Nondeterministic failure on SQLite < 3.27.': {
'expressions_window.tests.WindowFunctionTests.test_subquery_row_range_rank',
},
})
if self.connection.is_in_memory_db():
skips.update(
{
"the sqlite backend's close() method is a no-op when using an "
"in-memory database": {
"servers.test_liveserverthread.LiveServerThreadTest."
"test_closes_connections",
"servers.tests.LiveServerTestCloseConnectionTest."
"test_closes_connections",
},
}
)
skips.update({
"the sqlite backend's close() method is a no-op when using an "
"in-memory database": {
'servers.test_liveserverthread.LiveServerThreadTest.test_closes_connections',
},
})
return skips
@cached_property
def supports_atomic_references_rename(self):
# SQLite 3.28.0 bundled with MacOS 10.15 does not support renaming
# references atomically.
if platform.mac_ver()[0].startswith(
"10.15."
) and Database.sqlite_version_info == (3, 28, 0):
if platform.mac_ver()[0].startswith('10.15.') and Database.sqlite_version_info == (3, 28, 0):
return False
return Database.sqlite_version_info >= (3, 26, 0)
@cached_property
def introspected_field_types(self):
return {
return{
**super().introspected_field_types,
"BigAutoField": "AutoField",
"DurationField": "BigIntegerField",
"GenericIPAddressField": "CharField",
"SmallAutoField": "AutoField",
'BigAutoField': 'AutoField',
'DurationField': 'BigIntegerField',
'GenericIPAddressField': 'CharField',
'SmallAutoField': 'AutoField',
}
@cached_property
@@ -133,13 +109,5 @@ class DatabaseFeatures(BaseDatabaseFeatures):
return False
return True
can_introspect_json_field = property(operator.attrgetter("supports_json_field"))
has_json_object_function = property(operator.attrgetter("supports_json_field"))
@cached_property
def can_return_columns_from_insert(self):
return Database.sqlite_version_info >= (3, 35)
can_return_rows_from_bulk_insert = property(
operator.attrgetter("can_return_columns_from_insert")
)
can_introspect_json_field = property(operator.attrgetter('supports_json_field'))
has_json_object_function = property(operator.attrgetter('supports_json_field'))
@@ -3,21 +3,19 @@ from collections import namedtuple
import sqlparse
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.backends.base.introspection import (
BaseDatabaseIntrospection, FieldInfo as BaseFieldInfo, TableInfo,
)
from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile
FieldInfo = namedtuple(
"FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
)
FieldInfo = namedtuple('FieldInfo', BaseFieldInfo._fields + ('pk', 'has_json_constraint'))
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
field_size_re = _lazy_re_compile(r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$')
def get_field_size(name):
"""Extract the size number from a "varchar(11)" type name"""
""" Extract the size number from a "varchar(11)" type name """
m = field_size_re.search(name)
return int(m[1]) if m else None
@@ -30,29 +28,29 @@ class FlexibleFieldLookupDict:
# entries here because SQLite allows for anything and doesn't normalize the
# field type; it uses whatever was given.
base_data_types_reverse = {
"bool": "BooleanField",
"boolean": "BooleanField",
"smallint": "SmallIntegerField",
"smallint unsigned": "PositiveSmallIntegerField",
"smallinteger": "SmallIntegerField",
"int": "IntegerField",
"integer": "IntegerField",
"bigint": "BigIntegerField",
"integer unsigned": "PositiveIntegerField",
"bigint unsigned": "PositiveBigIntegerField",
"decimal": "DecimalField",
"real": "FloatField",
"text": "TextField",
"char": "CharField",
"varchar": "CharField",
"blob": "BinaryField",
"date": "DateField",
"datetime": "DateTimeField",
"time": "TimeField",
'bool': 'BooleanField',
'boolean': 'BooleanField',
'smallint': 'SmallIntegerField',
'smallint unsigned': 'PositiveSmallIntegerField',
'smallinteger': 'SmallIntegerField',
'int': 'IntegerField',
'integer': 'IntegerField',
'bigint': 'BigIntegerField',
'integer unsigned': 'PositiveIntegerField',
'bigint unsigned': 'PositiveBigIntegerField',
'decimal': 'DecimalField',
'real': 'FloatField',
'text': 'TextField',
'char': 'CharField',
'varchar': 'CharField',
'blob': 'BinaryField',
'date': 'DateField',
'datetime': 'DateTimeField',
'time': 'TimeField',
}
def __getitem__(self, key):
key = key.lower().split("(", 1)[0].strip()
key = key.lower().split('(', 1)[0].strip()
return self.base_data_types_reverse[key]
@@ -61,28 +59,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
def get_field_type(self, data_type, description):
field_type = super().get_field_type(data_type, description)
if description.pk and field_type in {
"BigIntegerField",
"IntegerField",
"SmallIntegerField",
}:
if description.pk and field_type in {'BigIntegerField', 'IntegerField', 'SmallIntegerField'}:
# No support for BigAutoField or SmallAutoField as SQLite treats
# all integer primary keys as signed 64-bit integers.
return "AutoField"
return 'AutoField'
if description.has_json_constraint:
return "JSONField"
return 'JSONField'
return field_type
def get_table_list(self, cursor):
"""Return a list of table and view names in the current database."""
# Skip the sqlite_sequence system table used for autoincrement key
# generation.
cursor.execute(
"""
cursor.execute("""
SELECT name, type FROM sqlite_master
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
ORDER BY name"""
)
ORDER BY name""")
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
def get_table_description(self, cursor, table_name):
@@ -90,9 +82,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
Return a description of the table with the DB-API cursor.description
interface.
"""
cursor.execute(
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
)
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(table_name))
table_info = cursor.fetchall()
collations = self._get_column_collations(cursor, table_name)
json_columns = set()
@@ -100,39 +90,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
for line in table_info:
column = line[1]
json_constraint_sql = '%%json_valid("%s")%%' % column
has_json_constraint = cursor.execute(
"""
has_json_constraint = cursor.execute("""
SELECT sql
FROM sqlite_master
WHERE
type = 'table' AND
name = %s AND
sql LIKE %s
""",
[table_name, json_constraint_sql],
).fetchone()
""", [table_name, json_constraint_sql]).fetchone()
if has_json_constraint:
json_columns.add(column)
return [
FieldInfo(
name,
data_type,
None,
get_field_size(data_type),
None,
None,
not notnull,
default,
collations.get(name),
pk == 1,
name in json_columns,
name, data_type, None, get_field_size(data_type), None, None,
not notnull, default, collations.get(name), pk == 1, name in json_columns
)
for cid, name, data_type, notnull, default, pk in table_info
]
def get_sequences(self, cursor, table_name, table_fields=()):
pk_col = self.get_primary_key_column(cursor, table_name)
return [{"table": table_name, "column": pk_col}]
return [{'table': table_name, 'column': pk_col}]
def get_relations(self, cursor, table_name):
"""
@@ -146,18 +124,18 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
cursor.execute(
"SELECT sql, type FROM sqlite_master "
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name],
[table_name]
)
create_sql, table_type = cursor.fetchone()
if table_type == "view":
if table_type == 'view':
# It might be a view, then no results will be returned
return relations
results = create_sql[create_sql.index("(") + 1 : create_sql.rindex(")")]
results = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_desc in results.split(","):
for field_desc in results.split(','):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
@@ -169,7 +147,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
if field_desc.startswith("FOREIGN KEY"):
# Find name of the target FK field
m = re.match(r"FOREIGN KEY\s*\(([^\)]*)\).*", field_desc, re.I)
m = re.match(r'FOREIGN KEY\s*\(([^\)]*)\).*', field_desc, re.I)
field_name = m[1].strip('"')
else:
field_name = field_desc.split()[0].strip('"')
@@ -177,15 +155,15 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
result = cursor.fetchall()[0]
other_table_results = result[0].strip()
li, ri = other_table_results.index("("), other_table_results.rindex(")")
other_table_results = other_table_results[li + 1 : ri]
li, ri = other_table_results.index('('), other_table_results.rindex(')')
other_table_results = other_table_results[li + 1:ri]
for other_desc in other_table_results.split(","):
for other_desc in other_table_results.split(','):
other_desc = other_desc.strip()
if other_desc.startswith("UNIQUE"):
if other_desc.startswith('UNIQUE'):
continue
other_name = other_desc.split(" ", 1)[0].strip('"')
other_name = other_desc.split(' ', 1)[0].strip('"')
if other_name == column:
relations[field_name] = (other_name, table)
break
@@ -200,17 +178,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
key_columns = []
# Schema for this table
cursor.execute(
"SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s",
[table_name, "table"],
)
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"])
results = cursor.fetchone()[0].strip()
results = results[results.index("(") + 1 : results.rindex(")")]
results = results[results.index('(') + 1:results.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_index, field_desc in enumerate(results.split(",")):
for field_index, field_desc in enumerate(results.split(',')):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
@@ -219,9 +194,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
if not m:
continue
# This will append
# (column_name, referenced_table_name, referenced_column_name) to
# key_columns.
# This will append (column_name, referenced_table_name, referenced_column_name) to key_columns
key_columns.append(tuple(s.strip('"') for s in m.groups()))
return key_columns
@@ -232,40 +205,36 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
cursor.execute(
"SELECT sql, type FROM sqlite_master "
"WHERE tbl_name = %s AND type IN ('table', 'view')",
[table_name],
[table_name]
)
row = cursor.fetchone()
if row is None:
raise ValueError("Table %s does not exist" % table_name)
create_sql, table_type = row
if table_type == "view":
if table_type == 'view':
# Views don't have a primary key.
return None
fields_sql = create_sql[create_sql.index("(") + 1 : create_sql.rindex(")")]
for field_desc in fields_sql.split(","):
fields_sql = create_sql[create_sql.index('(') + 1:create_sql.rindex(')')]
for field_desc in fields_sql.split(','):
field_desc = field_desc.strip()
m = re.match(
r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc
)
m = re.match(r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc)
if m:
return m[1] if m[1] else m[2]
return None
def _get_foreign_key_constraints(self, cursor, table_name):
constraints = {}
cursor.execute(
"PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
)
cursor.execute('PRAGMA foreign_key_list(%s)' % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
# Remaining on_update/on_delete/match values are of no interest.
id_, _, table, from_, to = row[:5]
constraints["fk_%d" % id_] = {
"columns": [from_],
"primary_key": False,
"unique": False,
"foreign_key": (table, to),
"check": False,
"index": False,
constraints['fk_%d' % id_] = {
'columns': [from_],
'primary_key': False,
'unique': False,
'foreign_key': (table, to),
'check': False,
'index': False,
}
return constraints
@@ -280,21 +249,19 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
check_columns = []
braces_deep = 0
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, "("):
if token.match(sqlparse.tokens.Punctuation, '('):
braces_deep += 1
elif token.match(sqlparse.tokens.Punctuation, ")"):
elif token.match(sqlparse.tokens.Punctuation, ')'):
braces_deep -= 1
if braces_deep < 0:
# End of columns and constraints for table definition.
break
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ','):
# End of current column or constraint definition.
break
# Detect column or constraint definition by first token.
if is_constraint_definition is None:
is_constraint_definition = token.match(
sqlparse.tokens.Keyword, "CONSTRAINT"
)
is_constraint_definition = token.match(sqlparse.tokens.Keyword, 'CONSTRAINT')
if is_constraint_definition:
continue
if is_constraint_definition:
@@ -305,7 +272,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
constraint_name = token.value[1:-1]
# Start constraint columns parsing after UNIQUE keyword.
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
unique = True
unique_braces_deep = braces_deep
elif unique:
@@ -325,10 +292,10 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
field_name = token.value
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
field_name = token.value[1:-1]
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
unique_columns = [field_name]
# Start constraint columns parsing after CHECK keyword.
if token.match(sqlparse.tokens.Keyword, "CHECK"):
if token.match(sqlparse.tokens.Keyword, 'CHECK'):
check = True
check_braces_deep = braces_deep
elif check:
@@ -343,30 +310,22 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
if token.value[1:-1] in columns:
check_columns.append(token.value[1:-1])
unique_constraint = (
{
"unique": True,
"columns": unique_columns,
"primary_key": False,
"foreign_key": None,
"check": False,
"index": False,
}
if unique_columns
else None
)
check_constraint = (
{
"check": True,
"columns": check_columns,
"primary_key": False,
"unique": False,
"foreign_key": None,
"index": False,
}
if check_columns
else None
)
unique_constraint = {
'unique': True,
'columns': unique_columns,
'primary_key': False,
'foreign_key': None,
'check': False,
'index': False,
} if unique_columns else None
check_constraint = {
'check': True,
'columns': check_columns,
'primary_key': False,
'unique': False,
'foreign_key': None,
'index': False,
} if check_columns else None
return constraint_name, unique_constraint, check_constraint, token
def _parse_table_constraints(self, sql, columns):
@@ -378,33 +337,24 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
tokens = (token for token in statement.flatten() if not token.is_whitespace)
# Go to columns and constraint definition
for token in tokens:
if token.match(sqlparse.tokens.Punctuation, "("):
if token.match(sqlparse.tokens.Punctuation, '('):
break
# Parse columns and constraint definition
while True:
(
constraint_name,
unique,
check,
end_token,
) = self._parse_column_or_constraint_definition(tokens, columns)
constraint_name, unique, check, end_token = self._parse_column_or_constraint_definition(tokens, columns)
if unique:
if constraint_name:
constraints[constraint_name] = unique
else:
unnamed_constrains_index += 1
constraints[
"__unnamed_constraint_%s__" % unnamed_constrains_index
] = unique
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = unique
if check:
if constraint_name:
constraints[constraint_name] = check
else:
unnamed_constrains_index += 1
constraints[
"__unnamed_constraint_%s__" % unnamed_constrains_index
] = check
if end_token.match(sqlparse.tokens.Punctuation, ")"):
constraints['__unnamed_constraint_%s__' % unnamed_constrains_index] = check
if end_token.match(sqlparse.tokens.Punctuation, ')'):
break
return constraints
@@ -417,22 +367,19 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# Find inline check constraints.
try:
table_schema = cursor.execute(
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s"
% (self.connection.ops.quote_name(table_name),)
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s" % (
self.connection.ops.quote_name(table_name),
)
).fetchone()[0]
except TypeError:
# table_name is a view.
pass
else:
columns = {
info.name for info in self.get_table_description(cursor, table_name)
}
columns = {info.name for info in self.get_table_description(cursor, table_name)}
constraints.update(self._parse_table_constraints(table_schema, columns))
# Get the index info
cursor.execute(
"PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
)
cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
for row in cursor.fetchall():
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
# columns. Discard last 2 columns if there.
@@ -442,7 +389,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
)
# There's at most one row.
(sql,) = cursor.fetchone() or (None,)
sql, = cursor.fetchone() or (None,)
# Inline constraints are already detected in
# _parse_table_constraints(). The reasons to avoid fetching inline
# constraints from `PRAGMA index_list` are:
@@ -453,9 +400,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
# An inline constraint
continue
# Get the index info for that index
cursor.execute(
"PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
)
cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index))
for index_rank, column_rank, column in cursor.fetchall():
if index not in constraints:
constraints[index] = {
@@ -466,14 +411,14 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
"check": False,
"index": True,
}
constraints[index]["columns"].append(column)
constraints[index]['columns'].append(column)
# Add type and column orders for indexes
if constraints[index]["index"]:
if constraints[index]['index']:
# SQLite doesn't support any index type other than b-tree
constraints[index]["type"] = Index.suffix
constraints[index]['type'] = Index.suffix
orders = self._get_index_columns_orders(sql)
if orders is not None:
constraints[index]["orders"] = orders
constraints[index]['orders'] = orders
# Get the PK
pk_column = self.get_primary_key_column(cursor, table_name)
if pk_column:
@@ -496,30 +441,27 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
tokens = sqlparse.parse(sql)[0]
for token in tokens:
if isinstance(token, sqlparse.sql.Parenthesis):
columns = str(token).strip("()").split(", ")
return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
columns = str(token).strip('()').split(', ')
return ['DESC' if info.endswith('DESC') else 'ASC' for info in columns]
return None
def _get_column_collations(self, cursor, table_name):
row = cursor.execute(
"""
row = cursor.execute("""
SELECT sql
FROM sqlite_master
WHERE type = 'table' AND name = %s
""",
[table_name],
).fetchone()
""", [table_name]).fetchone()
if not row:
return {}
sql = row[0]
columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
columns = str(sqlparse.parse(sql)[0][-1]).strip('()').split(', ')
collations = {}
for column in columns:
tokens = column[1:].split()
column_name = tokens[0].strip('"')
for index, token in enumerate(tokens):
if token == "COLLATE":
if token == 'COLLATE':
collation = tokens[index + 1]
break
else:
@@ -15,15 +15,12 @@ from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
cast_char_field_without_max_length = "text"
cast_char_field_without_max_length = 'text'
cast_data_types = {
"DateField": "TEXT",
"DateTimeField": "TEXT",
'DateField': 'TEXT',
'DateTimeField': 'TEXT',
}
explain_prefix = "EXPLAIN QUERY PLAN"
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
# SQLite. Use JSON_TYPE() instead.
jsonfield_datatype_values = frozenset(["null", "false", "true"])
explain_prefix = 'EXPLAIN QUERY PLAN'
def bulk_batch_size(self, fields, objs):
"""
@@ -54,14 +51,14 @@ class DatabaseOperations(BaseDatabaseOperations):
else:
if isinstance(output_field, bad_fields):
raise NotSupportedError(
"You cannot use Sum, Avg, StdDev, and Variance "
"aggregations on date/time fields in sqlite3 "
"since date/time is saved as text."
'You cannot use Sum, Avg, StdDev, and Variance '
'aggregations on date/time fields in sqlite3 '
'since date/time is saved as text.'
)
if (
isinstance(expression, models.Aggregate)
and expression.distinct
and len(expression.source_expressions) > 1
isinstance(expression, models.Aggregate) and
expression.distinct and
len(expression.source_expressions) > 1
):
raise NotSupportedError(
"SQLite doesn't support DISTINCT on aggregate functions "
@@ -76,13 +73,6 @@ class DatabaseOperations(BaseDatabaseOperations):
"""
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the list of returned data.
"""
return cursor.fetchall()
def format_for_duration_arithmetic(self, sql):
"""Do nothing since formatting is handled in the custom function."""
return sql
@@ -104,32 +94,26 @@ class DatabaseOperations(BaseDatabaseOperations):
def _convert_tznames_to_sql(self, tzname):
if tzname and settings.USE_TZ:
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
return "NULL", "NULL"
return 'NULL', 'NULL'
def datetime_cast_date_sql(self, field_name, tzname):
return "django_datetime_cast_date(%s, %s, %s)" % (
field_name,
*self._convert_tznames_to_sql(tzname),
return 'django_datetime_cast_date(%s, %s, %s)' % (
field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_cast_time_sql(self, field_name, tzname):
return "django_datetime_cast_time(%s, %s, %s)" % (
field_name,
*self._convert_tznames_to_sql(tzname),
return 'django_datetime_cast_time(%s, %s, %s)' % (
field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_extract_sql(self, lookup_type, field_name, tzname):
return "django_datetime_extract('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
)
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
return "django_datetime_trunc('%s', %s, %s, %s)" % (
lookup_type.lower(),
field_name,
*self._convert_tznames_to_sql(tzname),
lookup_type.lower(), field_name, *self._convert_tznames_to_sql(tzname),
)
def time_extract_sql(self, lookup_type, field_name):
@@ -151,11 +135,11 @@ class DatabaseOperations(BaseDatabaseOperations):
if len(params) > BATCH_SIZE:
results = ()
for index in range(0, len(params), BATCH_SIZE):
chunk = params[index : index + BATCH_SIZE]
chunk = params[index:index + BATCH_SIZE]
results += self._quote_params_for_last_executed_query(chunk)
return results
sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
# Bypass Django's wrappers and use the underlying sqlite3 connection
# to avoid logging this query - it would trigger infinite recursion.
cursor = self.connection.connection.cursor()
@@ -167,9 +151,7 @@ class DatabaseOperations(BaseDatabaseOperations):
def last_executed_query(self, cursor, sql, params):
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
# pysqlite_statement_bind_parameters(
# self->statement, parameters, allow_8bit_chars
# );
# pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
# Unfortunately there is no way to reach self->statement from Python,
# so we quote and substitute parameters manually.
if params:
@@ -222,20 +204,14 @@ class DatabaseOperations(BaseDatabaseOperations):
if tables and allow_cascade:
# Simulate TRUNCATE CASCADE by recursively collecting the tables
# referencing the tables to be flushed.
tables = set(
chain.from_iterable(self._references_graph(table) for table in tables)
)
sql = [
"%s %s %s;"
% (
style.SQL_KEYWORD("DELETE"),
style.SQL_KEYWORD("FROM"),
style.SQL_FIELD(self.quote_name(table)),
)
for table in tables
]
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
sql = ['%s %s %s;' % (
style.SQL_KEYWORD('DELETE'),
style.SQL_KEYWORD('FROM'),
style.SQL_FIELD(self.quote_name(table))
) for table in tables]
if reset_sequences:
sequences = [{"table": table} for table in tables]
sequences = [{'table': table} for table in tables]
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
return sql
@@ -243,18 +219,17 @@ class DatabaseOperations(BaseDatabaseOperations):
if not sequences:
return []
return [
"%s %s %s %s = 0 %s %s %s (%s);"
% (
style.SQL_KEYWORD("UPDATE"),
style.SQL_TABLE(self.quote_name("sqlite_sequence")),
style.SQL_KEYWORD("SET"),
style.SQL_FIELD(self.quote_name("seq")),
style.SQL_KEYWORD("WHERE"),
style.SQL_FIELD(self.quote_name("name")),
style.SQL_KEYWORD("IN"),
", ".join(
["'%s'" % sequence_info["table"] for sequence_info in sequences]
),
'%s %s %s %s = 0 %s %s %s (%s);' % (
style.SQL_KEYWORD('UPDATE'),
style.SQL_TABLE(self.quote_name('sqlite_sequence')),
style.SQL_KEYWORD('SET'),
style.SQL_FIELD(self.quote_name('seq')),
style.SQL_KEYWORD('WHERE'),
style.SQL_FIELD(self.quote_name('name')),
style.SQL_KEYWORD('IN'),
', '.join([
"'%s'" % sequence_info['table'] for sequence_info in sequences
]),
),
]
@@ -263,7 +238,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
@@ -271,10 +246,7 @@ class DatabaseOperations(BaseDatabaseOperations):
if settings.USE_TZ:
value = timezone.make_naive(value, self.connection.timezone)
else:
raise ValueError(
"SQLite backend does not support timezone-aware datetimes when "
"USE_TZ is False."
)
raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")
return str(value)
@@ -283,7 +255,7 @@ class DatabaseOperations(BaseDatabaseOperations):
return None
# Expression values are adapted by the database.
if hasattr(value, "resolve_expression"):
if hasattr(value, 'resolve_expression'):
return value
# SQLite doesn't support tz-aware datetimes
@@ -295,17 +267,17 @@ class DatabaseOperations(BaseDatabaseOperations):
def get_db_converters(self, expression):
converters = super().get_db_converters(expression)
internal_type = expression.output_field.get_internal_type()
if internal_type == "DateTimeField":
if internal_type == 'DateTimeField':
converters.append(self.convert_datetimefield_value)
elif internal_type == "DateField":
elif internal_type == 'DateField':
converters.append(self.convert_datefield_value)
elif internal_type == "TimeField":
elif internal_type == 'TimeField':
converters.append(self.convert_timefield_value)
elif internal_type == "DecimalField":
elif internal_type == 'DecimalField':
converters.append(self.get_decimalfield_converter(expression))
elif internal_type == "UUIDField":
elif internal_type == 'UUIDField':
converters.append(self.convert_uuidfield_value)
elif internal_type == "BooleanField":
elif internal_type in ('NullBooleanField', 'BooleanField'):
converters.append(self.convert_booleanfield_value)
return converters
@@ -334,22 +306,15 @@ class DatabaseOperations(BaseDatabaseOperations):
# float inaccuracy must be removed.
create_decimal = decimal.Context(prec=15).create_decimal_from_float
if isinstance(expression, Col):
quantize_value = decimal.Decimal(1).scaleb(
-expression.output_field.decimal_places
)
quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)
def converter(value, expression, connection):
if value is not None:
return create_decimal(value).quantize(
quantize_value, context=expression.output_field.context
)
return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
else:
def converter(value, expression, connection):
if value is not None:
return create_decimal(value)
return converter
def convert_uuidfield_value(self, value, expression, connection):
@@ -362,25 +327,26 @@ class DatabaseOperations(BaseDatabaseOperations):
def bulk_insert_sql(self, fields, placeholder_rows):
return " UNION ALL ".join(
"SELECT %s" % ", ".join(row) for row in placeholder_rows
"SELECT %s" % ", ".join(row)
for row in placeholder_rows
)
def combine_expression(self, connector, sub_expressions):
# SQLite doesn't have a ^ operator, so use the user-defined POWER
# function that's registered in connect().
if connector == "^":
return "POWER(%s)" % ",".join(sub_expressions)
elif connector == "#":
return "BITXOR(%s)" % ",".join(sub_expressions)
if connector == '^':
return 'POWER(%s)' % ','.join(sub_expressions)
elif connector == '#':
return 'BITXOR(%s)' % ','.join(sub_expressions)
return super().combine_expression(connector, sub_expressions)
def combine_duration_expression(self, connector, sub_expressions):
if connector not in ["+", "-", "*", "/"]:
raise DatabaseError("Invalid connector for timedelta: %s." % connector)
if connector not in ['+', '-']:
raise DatabaseError('Invalid connector for timedelta: %s.' % connector)
fn_params = ["'%s'" % connector] + sub_expressions
if len(fn_params) > 3:
raise ValueError("Too many params for timedelta operations.")
return "django_format_dtdelta(%s)" % ", ".join(fn_params)
raise ValueError('Too many params for timedelta operations.')
return "django_format_dtdelta(%s)" % ', '.join(fn_params)
def integer_field_range(self, internal_type):
# SQLite doesn't enforce any integer constraints
@@ -390,27 +356,9 @@ class DatabaseOperations(BaseDatabaseOperations):
lhs_sql, lhs_params = lhs
rhs_sql, rhs_params = rhs
params = (*lhs_params, *rhs_params)
if internal_type == "TimeField":
return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
if internal_type == 'TimeField':
return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
def insert_statement(self, ignore_conflicts=False):
return (
"INSERT OR IGNORE INTO"
if ignore_conflicts
else super().insert_statement(ignore_conflicts)
)
def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
if not fields:
return "", ()
columns = [
"%s.%s"
% (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
)
for field in fields
]
return "RETURNING %s" % ", ".join(columns), ()
return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
@@ -14,9 +14,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
sql_delete_table = "DROP TABLE %(table)s"
sql_create_fk = None
sql_create_inline_fk = (
"REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
)
sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
sql_delete_unique = "DROP INDEX %(name)s"
@@ -25,11 +23,11 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# disabled. Enforce it here for the duration of the schema edition.
if not self.connection.disable_constraint_checking():
raise NotSupportedError(
"SQLite schema editor cannot be used while foreign key "
"constraint checks are enabled. Make sure to disable them "
"before entering a transaction.atomic() context because "
"SQLite does not support disabling them in the middle of "
"a multi-statement transaction."
'SQLite schema editor cannot be used while foreign key '
'constraint checks are enabled. Make sure to disable them '
'before entering a transaction.atomic() context because '
'SQLite does not support disabling them in the middle of '
'a multi-statement transaction.'
)
return super().__enter__()
@@ -44,7 +42,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# security hardening).
try:
import sqlite3
value = sqlite3.adapt(value)
except ImportError:
pass
@@ -56,7 +53,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
elif isinstance(value, (Decimal, float, int)):
return str(value)
elif isinstance(value, str):
return "'%s'" % value.replace("'", "''")
return "'%s'" % value.replace("\'", "\'\'")
elif value is None:
return "NULL"
elif isinstance(value, (bytes, bytearray, memoryview)):
@@ -65,13 +62,9 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# character.
return "X'%s'" % value.hex()
else:
raise ValueError(
"Cannot quote parameter value %r of type %s" % (value, type(value))
)
raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))
def _is_referenced_by_fk_constraint(
self, table_name, column_name=None, ignore_self=False
):
def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False):
"""
Return whether or not the provided table name is referenced by another
one. If `column_name` is specified, only references pointing to that
@@ -82,36 +75,23 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
for other_table in self.connection.introspection.get_table_list(cursor):
if ignore_self and other_table.name == table_name:
continue
constraints = (
self.connection.introspection._get_foreign_key_constraints(
cursor, other_table.name
)
)
constraints = self.connection.introspection._get_foreign_key_constraints(cursor, other_table.name)
for constraint in constraints.values():
constraint_table, constraint_column = constraint["foreign_key"]
if constraint_table == table_name and (
column_name is None or constraint_column == column_name
):
constraint_table, constraint_column = constraint['foreign_key']
if (constraint_table == table_name and
(column_name is None or constraint_column == column_name)):
return True
return False
def alter_db_table(
self, model, old_db_table, new_db_table, disable_constraints=True
):
if (
not self.connection.features.supports_atomic_references_rename
and disable_constraints
and self._is_referenced_by_fk_constraint(old_db_table)
):
def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
if (not self.connection.features.supports_atomic_references_rename and
disable_constraints and self._is_referenced_by_fk_constraint(old_db_table)):
if self.connection.in_atomic_block:
raise NotSupportedError(
(
"Renaming the %r table while in a transaction is not "
"supported on SQLite < 3.26 because it would break referential "
"integrity. Try adding `atomic = False` to the Migration class."
)
% old_db_table
)
raise NotSupportedError((
'Renaming the %r table while in a transaction is not '
'supported on SQLite < 3.26 because it would break referential '
'integrity. Try adding `atomic = False` to the Migration class.'
) % old_db_table)
self.connection.enable_constraint_checking()
super().alter_db_table(model, old_db_table, new_db_table)
self.connection.disable_constraint_checking()
@@ -124,56 +104,42 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
old_field_name = old_field.name
table_name = model._meta.db_table
_, old_column_name = old_field.get_attname_column()
if (
new_field.name != old_field_name
and not self.connection.features.supports_atomic_references_rename
and self._is_referenced_by_fk_constraint(
table_name, old_column_name, ignore_self=True
)
):
if (new_field.name != old_field_name and
not self.connection.features.supports_atomic_references_rename and
self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)):
if self.connection.in_atomic_block:
raise NotSupportedError(
(
"Renaming the %r.%r column while in a transaction is not "
"supported on SQLite < 3.26 because it would break referential "
"integrity. Try adding `atomic = False` to the Migration class."
)
% (model._meta.db_table, old_field_name)
)
raise NotSupportedError((
'Renaming the %r.%r column while in a transaction is not '
'supported on SQLite < 3.26 because it would break referential '
'integrity. Try adding `atomic = False` to the Migration class.'
) % (model._meta.db_table, old_field_name))
with atomic(self.connection.alias):
super().alter_field(model, old_field, new_field, strict=strict)
# Follow SQLite's documented procedure for performing changes
# that don't affect the on-disk content.
# https://sqlite.org/lang_altertable.html#otheralter
with self.connection.cursor() as cursor:
schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
0
]
cursor.execute("PRAGMA writable_schema = 1")
schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0]
cursor.execute('PRAGMA writable_schema = 1')
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
new_column_name = new_field.get_attname_column()[1]
search = references_template % old_column_name
replacement = references_template % new_column_name
cursor.execute(
"UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
(search, replacement),
)
cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))
cursor.execute("PRAGMA writable_schema = 0")
cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement))
cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1))
cursor.execute('PRAGMA writable_schema = 0')
# The integrity check will raise an exception and rollback
# the transaction if the sqlite_master updates corrupt the
# database.
cursor.execute("PRAGMA integrity_check")
cursor.execute('PRAGMA integrity_check')
# Perform a VACUUM to refresh the database representation from
# the sqlite_master table.
with self.connection.cursor() as cursor:
cursor.execute("VACUUM")
cursor.execute('VACUUM')
else:
super().alter_field(model, old_field, new_field, strict=strict)
def _remake_table(
self, model, create_field=None, delete_field=None, alter_field=None
):
def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
"""
Shortcut to transform a model from old_model into new_model
@@ -194,7 +160,6 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# to an altered field.
def is_self_referential(f):
return f.is_relation and f.remote_field.model is model
# Work out the new fields dict / mapping
body = {
f.name: f.clone() if is_self_referential(f) else f
@@ -202,18 +167,14 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
}
# Since mapping might mix column names and default values,
# its values must be already quoted.
mapping = {
f.column: self.quote_name(f.column)
for f in model._meta.local_concrete_fields
}
mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
# This maps field names (not columns) for things like unique_together
rename_mapping = {}
# If any of the new or altered fields is introducing a new PK,
# remove the old one
restore_pk_field = None
if getattr(create_field, "primary_key", False) or (
alter_field and getattr(alter_field[1], "primary_key", False)
):
if getattr(create_field, 'primary_key', False) or (
alter_field and getattr(alter_field[1], 'primary_key', False)):
for name, field in list(body.items()):
if field.primary_key:
field.primary_key = False
@@ -237,8 +198,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
body[new_field.name] = new_field
if old_field.null and not new_field.null:
case_sql = "coalesce(%(col)s, %(default)s)" % {
"col": self.quote_name(old_field.column),
"default": self.quote_value(self.effective_default(new_field)),
'col': self.quote_name(old_field.column),
'default': self.quote_value(self.effective_default(new_field))
}
mapping[new_field.column] = case_sql
else:
@@ -249,10 +210,7 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
del body[delete_field.name]
del mapping[delete_field.column]
# Remove any implicit M2M tables
if (
delete_field.many_to_many
and delete_field.remote_field.through._meta.auto_created
):
if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
return self.delete_model(delete_field.remote_field.through)
# Work inside a new app registry
apps = Apps()
@@ -274,7 +232,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
indexes = model._meta.indexes
if delete_field:
indexes = [
index for index in indexes if delete_field.name not in index.fields
index for index in indexes
if delete_field.name not in index.fields
]
constraints = list(model._meta.constraints)
@@ -290,57 +249,52 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# This wouldn't be required if the schema editor was operating on model
# states instead of rendered models.
meta_contents = {
"app_label": model._meta.app_label,
"db_table": model._meta.db_table,
"unique_together": unique_together,
"index_together": index_together,
"indexes": indexes,
"constraints": constraints,
"apps": apps,
'app_label': model._meta.app_label,
'db_table': model._meta.db_table,
'unique_together': unique_together,
'index_together': index_together,
'indexes': indexes,
'constraints': constraints,
'apps': apps,
}
meta = type("Meta", (), meta_contents)
body_copy["Meta"] = meta
body_copy["__module__"] = model.__module__
body_copy['Meta'] = meta
body_copy['__module__'] = model.__module__
type(model._meta.object_name, model.__bases__, body_copy)
# Construct a model with a renamed table name.
body_copy = copy.deepcopy(body)
meta_contents = {
"app_label": model._meta.app_label,
"db_table": "new__%s" % strip_quotes(model._meta.db_table),
"unique_together": unique_together,
"index_together": index_together,
"indexes": indexes,
"constraints": constraints,
"apps": apps,
'app_label': model._meta.app_label,
'db_table': 'new__%s' % strip_quotes(model._meta.db_table),
'unique_together': unique_together,
'index_together': index_together,
'indexes': indexes,
'constraints': constraints,
'apps': apps,
}
meta = type("Meta", (), meta_contents)
body_copy["Meta"] = meta
body_copy["__module__"] = model.__module__
new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)
body_copy['Meta'] = meta
body_copy['__module__'] = model.__module__
new_model = type('New%s' % model._meta.object_name, model.__bases__, body_copy)
# Create a new table with the updated schema.
self.create_model(new_model)
# Copy data from the old table into the new table
self.execute(
"INSERT INTO %s (%s) SELECT %s FROM %s"
% (
self.quote_name(new_model._meta.db_table),
", ".join(self.quote_name(x) for x in mapping),
", ".join(mapping.values()),
self.quote_name(model._meta.db_table),
)
)
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(new_model._meta.db_table),
', '.join(self.quote_name(x) for x in mapping),
', '.join(mapping.values()),
self.quote_name(model._meta.db_table),
))
# Delete the old table to make way for the new
self.delete_model(model, handle_autom2m=False)
# Rename the new table to take way for the old
self.alter_db_table(
new_model,
new_model._meta.db_table,
model._meta.db_table,
new_model, new_model._meta.db_table, model._meta.db_table,
disable_constraints=False,
)
@@ -357,17 +311,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
super().delete_model(model)
else:
# Delete the table (and only that)
self.execute(
self.sql_delete_table
% {
"table": self.quote_name(model._meta.db_table),
}
)
self.execute(self.sql_delete_table % {
"table": self.quote_name(model._meta.db_table),
})
# Remove all deferred statements referencing the deleted table.
for sql in list(self.deferred_sql):
if isinstance(sql, Statement) and sql.references_table(
model._meta.db_table
):
if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
self.deferred_sql.remove(sql)
def add_field(self, model, field):
@@ -394,40 +343,21 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# For everything else, remake.
else:
# It might not actually have a column behind it
if field.db_parameters(connection=self.connection)["type"] is None:
if field.db_parameters(connection=self.connection)['type'] is None:
return
self._remake_table(model, delete_field=field)
def _alter_field(
self,
model,
old_field,
new_field,
old_type,
new_type,
old_db_params,
new_db_params,
strict=False,
):
def _alter_field(self, model, old_field, new_field, old_type, new_type,
old_db_params, new_db_params, strict=False):
"""Perform a "physical" (non-ManyToMany) field update."""
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
# changed and there aren't any constraints.
if (
self.connection.features.can_alter_table_rename_column
and old_field.column != new_field.column
and self.column_sql(model, old_field) == self.column_sql(model, new_field)
and not (
old_field.remote_field
and old_field.db_constraint
or new_field.remote_field
and new_field.db_constraint
)
):
return self.execute(
self._rename_field_sql(
model._meta.db_table, old_field, new_field, new_type
)
)
if (self.connection.features.can_alter_table_rename_column and
old_field.column != new_field.column and
self.column_sql(model, old_field) == self.column_sql(model, new_field) and
not (old_field.remote_field and old_field.db_constraint or
new_field.remote_field and new_field.db_constraint)):
return self.execute(self._rename_field_sql(model._meta.db_table, old_field, new_field, new_type))
# Alter by remaking table
self._remake_table(model, alter_field=(old_field, new_field))
# Rebuild tables with FKs pointing to this field.
@@ -455,25 +385,15 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def _alter_many_to_many(self, model, old_field, new_field, strict):
"""Alter M2Ms to repoint their to= endpoints."""
if (
old_field.remote_field.through._meta.db_table
== new_field.remote_field.through._meta.db_table
):
# The field name didn't change, but some options did, so we have to
# propagate this altering.
if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table:
# The field name didn't change, but some options did; we have to propagate this altering.
self._remake_table(
old_field.remote_field.through,
alter_field=(
# The field that points to the target model is needed, so
# we can tell alter_field to change it - this is
# m2m_reverse_field_name() (as opposed to m2m_field_name(),
# which points to our model).
old_field.remote_field.through._meta.get_field(
old_field.m2m_reverse_field_name()
),
new_field.remote_field.through._meta.get_field(
new_field.m2m_reverse_field_name()
),
# We need the field that points to the target model, so we can tell alter_field to change it -
# this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model)
old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
),
)
return
@@ -481,51 +401,34 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
# Make a new through table
self.create_model(new_field.remote_field.through)
# Copy the data across
self.execute(
"INSERT INTO %s (%s) SELECT %s FROM %s"
% (
self.quote_name(new_field.remote_field.through._meta.db_table),
", ".join(
[
"id",
new_field.m2m_column_name(),
new_field.m2m_reverse_name(),
]
),
", ".join(
[
"id",
old_field.m2m_column_name(),
old_field.m2m_reverse_name(),
]
),
self.quote_name(old_field.remote_field.through._meta.db_table),
)
)
self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
self.quote_name(new_field.remote_field.through._meta.db_table),
', '.join([
"id",
new_field.m2m_column_name(),
new_field.m2m_reverse_name(),
]),
', '.join([
"id",
old_field.m2m_column_name(),
old_field.m2m_reverse_name(),
]),
self.quote_name(old_field.remote_field.through._meta.db_table),
))
# Delete the old through table
self.delete_model(old_field.remote_field.through)
def add_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition
or constraint.contains_expressions
or constraint.include
or constraint.deferrable
):
if isinstance(constraint, UniqueConstraint) and constraint.condition:
super().add_constraint(model, constraint)
else:
self._remake_table(model)
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and (
constraint.condition
or constraint.contains_expressions
or constraint.include
or constraint.deferrable
):
if isinstance(constraint, UniqueConstraint) and constraint.condition:
super().remove_constraint(model, constraint)
else:
self._remake_table(model)
def _collate_sql(self, collation):
return "COLLATE " + collation
return ' COLLATE ' + collation