Initial commit
This commit is contained in:
@@ -0,0 +1,695 @@
|
||||
"""
|
||||
SQLite backend for the sqlite3 module in the standard library.
|
||||
"""
|
||||
import datetime
|
||||
import decimal
|
||||
import functools
|
||||
import hashlib
|
||||
import math
|
||||
import operator
|
||||
import random
|
||||
import re
|
||||
import statistics
|
||||
import warnings
|
||||
from itertools import chain
|
||||
from sqlite3 import dbapi2 as Database
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import IntegrityError
|
||||
from django.db.backends import utils as backend_utils
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper, timezone_constructor
|
||||
from django.utils import timezone
|
||||
from django.utils.asyncio import async_unsafe
|
||||
from django.utils.dateparse import parse_datetime, parse_time
|
||||
from django.utils.duration import duration_microseconds
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
from .client import DatabaseClient
|
||||
from .creation import DatabaseCreation
|
||||
from .features import DatabaseFeatures
|
||||
from .introspection import DatabaseIntrospection
|
||||
from .operations import DatabaseOperations
|
||||
from .schema import DatabaseSchemaEditor
|
||||
|
||||
|
||||
def decoder(conv_func):
|
||||
"""
|
||||
Convert bytestrings from Python's sqlite3 interface to a regular string.
|
||||
"""
|
||||
return lambda s: conv_func(s.decode())
|
||||
|
||||
|
||||
def none_guard(func):
|
||||
"""
|
||||
Decorator that returns None if any of the arguments to the decorated
|
||||
function are None. Many SQL functions return NULL if any of their arguments
|
||||
are NULL. This decorator simplifies the implementation of this for the
|
||||
custom functions registered below.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return None if None in args else func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def list_aggregate(function):
|
||||
"""
|
||||
Return an aggregate class that accumulates values in a list and applies
|
||||
the provided function to the data.
|
||||
"""
|
||||
return type("ListAggregate", (list,), {"finalize": function, "step": list.append})
|
||||
|
||||
|
||||
def check_sqlite_version():
|
||||
if Database.sqlite_version_info < (3, 9, 0):
|
||||
raise ImproperlyConfigured(
|
||||
"SQLite 3.9.0 or later is required (found %s)." % Database.sqlite_version
|
||||
)
|
||||
|
||||
|
||||
check_sqlite_version()
|
||||
|
||||
Database.register_converter("bool", b"1".__eq__)
|
||||
Database.register_converter("time", decoder(parse_time))
|
||||
Database.register_converter("datetime", decoder(parse_datetime))
|
||||
Database.register_converter("timestamp", decoder(parse_datetime))
|
||||
|
||||
Database.register_adapter(decimal.Decimal, str)
|
||||
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
vendor = "sqlite"
|
||||
display_name = "SQLite"
|
||||
# SQLite doesn't actually support most of these types, but it "does the right
|
||||
# thing" given more verbose field definitions, so leave them as is so that
|
||||
# schema inspection is more useful.
|
||||
data_types = {
|
||||
"AutoField": "integer",
|
||||
"BigAutoField": "integer",
|
||||
"BinaryField": "BLOB",
|
||||
"BooleanField": "bool",
|
||||
"CharField": "varchar(%(max_length)s)",
|
||||
"DateField": "date",
|
||||
"DateTimeField": "datetime",
|
||||
"DecimalField": "decimal",
|
||||
"DurationField": "bigint",
|
||||
"FileField": "varchar(%(max_length)s)",
|
||||
"FilePathField": "varchar(%(max_length)s)",
|
||||
"FloatField": "real",
|
||||
"IntegerField": "integer",
|
||||
"BigIntegerField": "bigint",
|
||||
"IPAddressField": "char(15)",
|
||||
"GenericIPAddressField": "char(39)",
|
||||
"JSONField": "text",
|
||||
"OneToOneField": "integer",
|
||||
"PositiveBigIntegerField": "bigint unsigned",
|
||||
"PositiveIntegerField": "integer unsigned",
|
||||
"PositiveSmallIntegerField": "smallint unsigned",
|
||||
"SlugField": "varchar(%(max_length)s)",
|
||||
"SmallAutoField": "integer",
|
||||
"SmallIntegerField": "smallint",
|
||||
"TextField": "text",
|
||||
"TimeField": "time",
|
||||
"UUIDField": "char(32)",
|
||||
}
|
||||
data_type_check_constraints = {
|
||||
"PositiveBigIntegerField": '"%(column)s" >= 0',
|
||||
"JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
|
||||
"PositiveIntegerField": '"%(column)s" >= 0',
|
||||
"PositiveSmallIntegerField": '"%(column)s" >= 0',
|
||||
}
|
||||
data_types_suffix = {
|
||||
"AutoField": "AUTOINCREMENT",
|
||||
"BigAutoField": "AUTOINCREMENT",
|
||||
"SmallAutoField": "AUTOINCREMENT",
|
||||
}
|
||||
# SQLite requires LIKE statements to include an ESCAPE clause if the value
|
||||
# being escaped has a percent or underscore in it.
|
||||
# See https://www.sqlite.org/lang_expr.html for an explanation.
|
||||
operators = {
|
||||
"exact": "= %s",
|
||||
"iexact": "LIKE %s ESCAPE '\\'",
|
||||
"contains": "LIKE %s ESCAPE '\\'",
|
||||
"icontains": "LIKE %s ESCAPE '\\'",
|
||||
"regex": "REGEXP %s",
|
||||
"iregex": "REGEXP '(?i)' || %s",
|
||||
"gt": "> %s",
|
||||
"gte": ">= %s",
|
||||
"lt": "< %s",
|
||||
"lte": "<= %s",
|
||||
"startswith": "LIKE %s ESCAPE '\\'",
|
||||
"endswith": "LIKE %s ESCAPE '\\'",
|
||||
"istartswith": "LIKE %s ESCAPE '\\'",
|
||||
"iendswith": "LIKE %s ESCAPE '\\'",
|
||||
}
|
||||
|
||||
# The patterns below are used to generate SQL pattern lookup clauses when
|
||||
# the right-hand side of the lookup isn't a raw string (it might be an expression
|
||||
# or the result of a bilateral transformation).
|
||||
# In those cases, special characters for LIKE operators (e.g. \, *, _) should be
|
||||
# escaped on database side.
|
||||
#
|
||||
# Note: we use str.format() here for readability as '%' is used as a wildcard for
|
||||
# the LIKE operator.
|
||||
pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
|
||||
pattern_ops = {
|
||||
"contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
|
||||
"icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
|
||||
"startswith": r"LIKE {} || '%%' ESCAPE '\'",
|
||||
"istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
|
||||
"endswith": r"LIKE '%%' || {} ESCAPE '\'",
|
||||
"iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
|
||||
}
|
||||
|
||||
Database = Database
|
||||
SchemaEditorClass = DatabaseSchemaEditor
|
||||
# Classes instantiated in __init__().
|
||||
client_class = DatabaseClient
|
||||
creation_class = DatabaseCreation
|
||||
features_class = DatabaseFeatures
|
||||
introspection_class = DatabaseIntrospection
|
||||
ops_class = DatabaseOperations
|
||||
|
||||
def get_connection_params(self):
|
||||
settings_dict = self.settings_dict
|
||||
if not settings_dict["NAME"]:
|
||||
raise ImproperlyConfigured(
|
||||
"settings.DATABASES is improperly configured. "
|
||||
"Please supply the NAME value."
|
||||
)
|
||||
kwargs = {
|
||||
"database": settings_dict["NAME"],
|
||||
"detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
|
||||
**settings_dict["OPTIONS"],
|
||||
}
|
||||
# Always allow the underlying SQLite connection to be shareable
|
||||
# between multiple threads. The safe-guarding will be handled at a
|
||||
# higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
|
||||
# property. This is necessary as the shareability is disabled by
|
||||
# default in pysqlite and it cannot be changed once a connection is
|
||||
# opened.
|
||||
if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
|
||||
warnings.warn(
|
||||
"The `check_same_thread` option was provided and set to "
|
||||
"True. It will be overridden with False. Use the "
|
||||
"`DatabaseWrapper.allow_thread_sharing` property instead "
|
||||
"for controlling thread shareability.",
|
||||
RuntimeWarning,
|
||||
)
|
||||
kwargs.update({"check_same_thread": False, "uri": True})
|
||||
return kwargs
|
||||
|
||||
@async_unsafe
|
||||
def get_new_connection(self, conn_params):
|
||||
conn = Database.connect(**conn_params)
|
||||
create_deterministic_function = functools.partial(
|
||||
conn.create_function,
|
||||
deterministic=True,
|
||||
)
|
||||
create_deterministic_function(
|
||||
"django_date_extract", 2, _sqlite_datetime_extract
|
||||
)
|
||||
create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
|
||||
create_deterministic_function(
|
||||
"django_datetime_cast_date", 3, _sqlite_datetime_cast_date
|
||||
)
|
||||
create_deterministic_function(
|
||||
"django_datetime_cast_time", 3, _sqlite_datetime_cast_time
|
||||
)
|
||||
create_deterministic_function(
|
||||
"django_datetime_extract", 4, _sqlite_datetime_extract
|
||||
)
|
||||
create_deterministic_function(
|
||||
"django_datetime_trunc", 4, _sqlite_datetime_trunc
|
||||
)
|
||||
create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
|
||||
create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
|
||||
create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
|
||||
create_deterministic_function(
|
||||
"django_timestamp_diff", 2, _sqlite_timestamp_diff
|
||||
)
|
||||
create_deterministic_function(
|
||||
"django_format_dtdelta", 3, _sqlite_format_dtdelta
|
||||
)
|
||||
create_deterministic_function("regexp", 2, _sqlite_regexp)
|
||||
create_deterministic_function("ACOS", 1, none_guard(math.acos))
|
||||
create_deterministic_function("ASIN", 1, none_guard(math.asin))
|
||||
create_deterministic_function("ATAN", 1, none_guard(math.atan))
|
||||
create_deterministic_function("ATAN2", 2, none_guard(math.atan2))
|
||||
create_deterministic_function("BITXOR", 2, none_guard(operator.xor))
|
||||
create_deterministic_function("CEILING", 1, none_guard(math.ceil))
|
||||
create_deterministic_function("COS", 1, none_guard(math.cos))
|
||||
create_deterministic_function("COT", 1, none_guard(lambda x: 1 / math.tan(x)))
|
||||
create_deterministic_function("DEGREES", 1, none_guard(math.degrees))
|
||||
create_deterministic_function("EXP", 1, none_guard(math.exp))
|
||||
create_deterministic_function("FLOOR", 1, none_guard(math.floor))
|
||||
create_deterministic_function("LN", 1, none_guard(math.log))
|
||||
create_deterministic_function("LOG", 2, none_guard(lambda x, y: math.log(y, x)))
|
||||
create_deterministic_function("LPAD", 3, _sqlite_lpad)
|
||||
create_deterministic_function(
|
||||
"MD5", 1, none_guard(lambda x: hashlib.md5(x.encode()).hexdigest())
|
||||
)
|
||||
create_deterministic_function("MOD", 2, none_guard(math.fmod))
|
||||
create_deterministic_function("PI", 0, lambda: math.pi)
|
||||
create_deterministic_function("POWER", 2, none_guard(operator.pow))
|
||||
create_deterministic_function("RADIANS", 1, none_guard(math.radians))
|
||||
create_deterministic_function("REPEAT", 2, none_guard(operator.mul))
|
||||
create_deterministic_function("REVERSE", 1, none_guard(lambda x: x[::-1]))
|
||||
create_deterministic_function("RPAD", 3, _sqlite_rpad)
|
||||
create_deterministic_function(
|
||||
"SHA1", 1, none_guard(lambda x: hashlib.sha1(x.encode()).hexdigest())
|
||||
)
|
||||
create_deterministic_function(
|
||||
"SHA224", 1, none_guard(lambda x: hashlib.sha224(x.encode()).hexdigest())
|
||||
)
|
||||
create_deterministic_function(
|
||||
"SHA256", 1, none_guard(lambda x: hashlib.sha256(x.encode()).hexdigest())
|
||||
)
|
||||
create_deterministic_function(
|
||||
"SHA384", 1, none_guard(lambda x: hashlib.sha384(x.encode()).hexdigest())
|
||||
)
|
||||
create_deterministic_function(
|
||||
"SHA512", 1, none_guard(lambda x: hashlib.sha512(x.encode()).hexdigest())
|
||||
)
|
||||
create_deterministic_function(
|
||||
"SIGN", 1, none_guard(lambda x: (x > 0) - (x < 0))
|
||||
)
|
||||
create_deterministic_function("SIN", 1, none_guard(math.sin))
|
||||
create_deterministic_function("SQRT", 1, none_guard(math.sqrt))
|
||||
create_deterministic_function("TAN", 1, none_guard(math.tan))
|
||||
# Don't use the built-in RANDOM() function because it returns a value
|
||||
# in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
|
||||
conn.create_function("RAND", 0, random.random)
|
||||
conn.create_aggregate("STDDEV_POP", 1, list_aggregate(statistics.pstdev))
|
||||
conn.create_aggregate("STDDEV_SAMP", 1, list_aggregate(statistics.stdev))
|
||||
conn.create_aggregate("VAR_POP", 1, list_aggregate(statistics.pvariance))
|
||||
conn.create_aggregate("VAR_SAMP", 1, list_aggregate(statistics.variance))
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
return conn
|
||||
|
||||
def init_connection_state(self):
|
||||
pass
|
||||
|
||||
def create_cursor(self, name=None):
|
||||
return self.connection.cursor(factory=SQLiteCursorWrapper)
|
||||
|
||||
@async_unsafe
|
||||
def close(self):
|
||||
self.validate_thread_sharing()
|
||||
# If database is in memory, closing the connection destroys the
|
||||
# database. To prevent accidental data loss, ignore close requests on
|
||||
# an in-memory db.
|
||||
if not self.is_in_memory_db():
|
||||
BaseDatabaseWrapper.close(self)
|
||||
|
||||
def _savepoint_allowed(self):
|
||||
# When 'isolation_level' is not None, sqlite3 commits before each
|
||||
# savepoint; it's a bug. When it is None, savepoints don't make sense
|
||||
# because autocommit is enabled. The only exception is inside 'atomic'
|
||||
# blocks. To work around that bug, on SQLite, 'atomic' starts a
|
||||
# transaction explicitly rather than simply disable autocommit.
|
||||
return self.in_atomic_block
|
||||
|
||||
def _set_autocommit(self, autocommit):
|
||||
if autocommit:
|
||||
level = None
|
||||
else:
|
||||
# sqlite3's internal default is ''. It's different from None.
|
||||
# See Modules/_sqlite/connection.c.
|
||||
level = ""
|
||||
# 'isolation_level' is a misleading API.
|
||||
# SQLite always runs at the SERIALIZABLE isolation level.
|
||||
with self.wrap_database_errors:
|
||||
self.connection.isolation_level = level
|
||||
|
||||
def disable_constraint_checking(self):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("PRAGMA foreign_keys = OFF")
|
||||
# Foreign key constraints cannot be turned off while in a multi-
|
||||
# statement transaction. Fetch the current state of the pragma
|
||||
# to determine if constraints are effectively disabled.
|
||||
enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
|
||||
return not bool(enabled)
|
||||
|
||||
def enable_constraint_checking(self):
|
||||
with self.cursor() as cursor:
|
||||
cursor.execute("PRAGMA foreign_keys = ON")
|
||||
|
||||
def check_constraints(self, table_names=None):
|
||||
"""
|
||||
Check each table name in `table_names` for rows with invalid foreign
|
||||
key references. This method is intended to be used in conjunction with
|
||||
`disable_constraint_checking()` and `enable_constraint_checking()`, to
|
||||
determine if rows with invalid references were entered while constraint
|
||||
checks were off.
|
||||
"""
|
||||
if self.features.supports_pragma_foreign_key_check:
|
||||
with self.cursor() as cursor:
|
||||
if table_names is None:
|
||||
violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
|
||||
else:
|
||||
violations = chain.from_iterable(
|
||||
cursor.execute(
|
||||
"PRAGMA foreign_key_check(%s)"
|
||||
% self.ops.quote_name(table_name)
|
||||
).fetchall()
|
||||
for table_name in table_names
|
||||
)
|
||||
# See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
|
||||
for (
|
||||
table_name,
|
||||
rowid,
|
||||
referenced_table_name,
|
||||
foreign_key_index,
|
||||
) in violations:
|
||||
foreign_key = cursor.execute(
|
||||
"PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
|
||||
).fetchall()[foreign_key_index]
|
||||
column_name, referenced_column_name = foreign_key[3:5]
|
||||
primary_key_column_name = self.introspection.get_primary_key_column(
|
||||
cursor, table_name
|
||||
)
|
||||
primary_key_value, bad_value = cursor.execute(
|
||||
"SELECT %s, %s FROM %s WHERE rowid = %%s"
|
||||
% (
|
||||
self.ops.quote_name(primary_key_column_name),
|
||||
self.ops.quote_name(column_name),
|
||||
self.ops.quote_name(table_name),
|
||||
),
|
||||
(rowid,),
|
||||
).fetchone()
|
||||
raise IntegrityError(
|
||||
"The row in table '%s' with primary key '%s' has an "
|
||||
"invalid foreign key: %s.%s contains a value '%s' that "
|
||||
"does not have a corresponding value in %s.%s."
|
||||
% (
|
||||
table_name,
|
||||
primary_key_value,
|
||||
table_name,
|
||||
column_name,
|
||||
bad_value,
|
||||
referenced_table_name,
|
||||
referenced_column_name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
with self.cursor() as cursor:
|
||||
if table_names is None:
|
||||
table_names = self.introspection.table_names(cursor)
|
||||
for table_name in table_names:
|
||||
primary_key_column_name = self.introspection.get_primary_key_column(
|
||||
cursor, table_name
|
||||
)
|
||||
if not primary_key_column_name:
|
||||
continue
|
||||
key_columns = self.introspection.get_key_columns(cursor, table_name)
|
||||
for (
|
||||
column_name,
|
||||
referenced_table_name,
|
||||
referenced_column_name,
|
||||
) in key_columns:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
|
||||
LEFT JOIN `%s` as REFERRED
|
||||
ON (REFERRING.`%s` = REFERRED.`%s`)
|
||||
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
|
||||
"""
|
||||
% (
|
||||
primary_key_column_name,
|
||||
column_name,
|
||||
table_name,
|
||||
referenced_table_name,
|
||||
column_name,
|
||||
referenced_column_name,
|
||||
column_name,
|
||||
referenced_column_name,
|
||||
)
|
||||
)
|
||||
for bad_row in cursor.fetchall():
|
||||
raise IntegrityError(
|
||||
"The row in table '%s' with primary key '%s' has an "
|
||||
"invalid foreign key: %s.%s contains a value '%s' that "
|
||||
"does not have a corresponding value in %s.%s."
|
||||
% (
|
||||
table_name,
|
||||
bad_row[0],
|
||||
table_name,
|
||||
column_name,
|
||||
bad_row[1],
|
||||
referenced_table_name,
|
||||
referenced_column_name,
|
||||
)
|
||||
)
|
||||
|
||||
def is_usable(self):
|
||||
return True
|
||||
|
||||
def _start_transaction_under_autocommit(self):
|
||||
"""
|
||||
Start a transaction explicitly in autocommit mode.
|
||||
|
||||
Staying in autocommit mode works around a bug of sqlite3 that breaks
|
||||
savepoints when autocommit is disabled.
|
||||
"""
|
||||
self.cursor().execute("BEGIN")
|
||||
|
||||
def is_in_memory_db(self):
|
||||
return self.creation.is_in_memory_db(self.settings_dict["NAME"])
|
||||
|
||||
|
||||
FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
|
||||
|
||||
|
||||
class SQLiteCursorWrapper(Database.Cursor):
|
||||
"""
|
||||
Django uses "format" style placeholders, but pysqlite2 uses "qmark" style.
|
||||
This fixes it -- but note that if you want to use a literal "%s" in a query,
|
||||
you'll need to use "%%s".
|
||||
"""
|
||||
|
||||
def execute(self, query, params=None):
|
||||
if params is None:
|
||||
return Database.Cursor.execute(self, query)
|
||||
query = self.convert_query(query)
|
||||
return Database.Cursor.execute(self, query, params)
|
||||
|
||||
def executemany(self, query, param_list):
|
||||
query = self.convert_query(query)
|
||||
return Database.Cursor.executemany(self, query, param_list)
|
||||
|
||||
def convert_query(self, query):
|
||||
return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
|
||||
|
||||
|
||||
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
dt = backend_utils.typecast_timestamp(dt)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if conn_tzname:
|
||||
dt = dt.replace(tzinfo=timezone_constructor(conn_tzname))
|
||||
if tzname is not None and tzname != conn_tzname:
|
||||
tzname, sign, offset = backend_utils.split_tzname_delta(tzname)
|
||||
if offset:
|
||||
hours, minutes = offset.split(":")
|
||||
offset_delta = datetime.timedelta(hours=int(hours), minutes=int(minutes))
|
||||
dt += offset_delta if sign == "+" else -offset_delta
|
||||
dt = timezone.localtime(dt, timezone_constructor(tzname))
|
||||
return dt
|
||||
|
||||
|
||||
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "year":
|
||||
return "%i-01-01" % dt.year
|
||||
elif lookup_type == "quarter":
|
||||
month_in_quarter = dt.month - (dt.month - 1) % 3
|
||||
return "%i-%02i-01" % (dt.year, month_in_quarter)
|
||||
elif lookup_type == "month":
|
||||
return "%i-%02i-01" % (dt.year, dt.month)
|
||||
elif lookup_type == "week":
|
||||
dt = dt - datetime.timedelta(days=dt.weekday())
|
||||
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
|
||||
elif lookup_type == "day":
|
||||
return "%i-%02i-%02i" % (dt.year, dt.month, dt.day)
|
||||
|
||||
|
||||
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
if dt is None:
|
||||
return None
|
||||
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt_parsed is None:
|
||||
try:
|
||||
dt = backend_utils.typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
else:
|
||||
dt = dt_parsed
|
||||
if lookup_type == "hour":
|
||||
return "%02i:00:00" % dt.hour
|
||||
elif lookup_type == "minute":
|
||||
return "%02i:%02i:00" % (dt.hour, dt.minute)
|
||||
elif lookup_type == "second":
|
||||
return "%02i:%02i:%02i" % (dt.hour, dt.minute, dt.second)
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.date().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.time().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "week_day":
|
||||
return (dt.isoweekday() % 7) + 1
|
||||
elif lookup_type == "iso_week_day":
|
||||
return dt.isoweekday()
|
||||
elif lookup_type == "week":
|
||||
return dt.isocalendar()[1]
|
||||
elif lookup_type == "quarter":
|
||||
return math.ceil(dt.month / 3)
|
||||
elif lookup_type == "iso_year":
|
||||
return dt.isocalendar()[0]
|
||||
else:
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "year":
|
||||
return "%i-01-01 00:00:00" % dt.year
|
||||
elif lookup_type == "quarter":
|
||||
month_in_quarter = dt.month - (dt.month - 1) % 3
|
||||
return "%i-%02i-01 00:00:00" % (dt.year, month_in_quarter)
|
||||
elif lookup_type == "month":
|
||||
return "%i-%02i-01 00:00:00" % (dt.year, dt.month)
|
||||
elif lookup_type == "week":
|
||||
dt = dt - datetime.timedelta(days=dt.weekday())
|
||||
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
|
||||
elif lookup_type == "day":
|
||||
return "%i-%02i-%02i 00:00:00" % (dt.year, dt.month, dt.day)
|
||||
elif lookup_type == "hour":
|
||||
return "%i-%02i-%02i %02i:00:00" % (dt.year, dt.month, dt.day, dt.hour)
|
||||
elif lookup_type == "minute":
|
||||
return "%i-%02i-%02i %02i:%02i:00" % (
|
||||
dt.year,
|
||||
dt.month,
|
||||
dt.day,
|
||||
dt.hour,
|
||||
dt.minute,
|
||||
)
|
||||
elif lookup_type == "second":
|
||||
return "%i-%02i-%02i %02i:%02i:%02i" % (
|
||||
dt.year,
|
||||
dt.month,
|
||||
dt.day,
|
||||
dt.hour,
|
||||
dt.minute,
|
||||
dt.second,
|
||||
)
|
||||
|
||||
|
||||
def _sqlite_time_extract(lookup_type, dt):
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
dt = backend_utils.typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_prepare_dtdelta_param(conn, param):
|
||||
if conn in ["+", "-"]:
|
||||
if isinstance(param, int):
|
||||
return datetime.timedelta(0, 0, param)
|
||||
else:
|
||||
return backend_utils.typecast_timestamp(param)
|
||||
return param
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_format_dtdelta(conn, lhs, rhs):
|
||||
"""
|
||||
LHS and RHS can be either:
|
||||
- An integer number of microseconds
|
||||
- A string representing a datetime
|
||||
- A scalar value, e.g. float
|
||||
"""
|
||||
conn = conn.strip()
|
||||
try:
|
||||
real_lhs = _sqlite_prepare_dtdelta_param(conn, lhs)
|
||||
real_rhs = _sqlite_prepare_dtdelta_param(conn, rhs)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if conn == "+":
|
||||
# typecast_timestamp returns a date or a datetime without timezone.
|
||||
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
|
||||
out = str(real_lhs + real_rhs)
|
||||
elif conn == "-":
|
||||
out = str(real_lhs - real_rhs)
|
||||
elif conn == "*":
|
||||
out = real_lhs * real_rhs
|
||||
else:
|
||||
out = real_lhs / real_rhs
|
||||
return out
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_time_diff(lhs, rhs):
|
||||
left = backend_utils.typecast_time(lhs)
|
||||
right = backend_utils.typecast_time(rhs)
|
||||
return (
|
||||
(left.hour * 60 * 60 * 1000000)
|
||||
+ (left.minute * 60 * 1000000)
|
||||
+ (left.second * 1000000)
|
||||
+ (left.microsecond)
|
||||
- (right.hour * 60 * 60 * 1000000)
|
||||
- (right.minute * 60 * 1000000)
|
||||
- (right.second * 1000000)
|
||||
- (right.microsecond)
|
||||
)
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_timestamp_diff(lhs, rhs):
|
||||
left = backend_utils.typecast_timestamp(lhs)
|
||||
right = backend_utils.typecast_timestamp(rhs)
|
||||
return duration_microseconds(left - right)
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_regexp(re_pattern, re_string):
|
||||
return bool(re.search(re_pattern, str(re_string)))
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_lpad(text, length, fill_text):
|
||||
if len(text) >= length:
|
||||
return text[:length]
|
||||
return (fill_text * length)[: length - len(text)] + text
|
||||
|
||||
|
||||
@none_guard
|
||||
def _sqlite_rpad(text, length, fill_text):
|
||||
return (text + fill_text * length)[:length]
|
||||
@@ -0,0 +1,10 @@
|
||||
from django.db.backends.base.client import BaseDatabaseClient
|
||||
|
||||
|
||||
class DatabaseClient(BaseDatabaseClient):
|
||||
executable_name = "sqlite3"
|
||||
|
||||
@classmethod
|
||||
def settings_to_cmd_args_env(cls, settings_dict, parameters):
|
||||
args = [cls.executable_name, settings_dict["NAME"], *parameters]
|
||||
return args, None
|
||||
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
@staticmethod
|
||||
def is_in_memory_db(database_name):
|
||||
return not isinstance(database_name, Path) and (
|
||||
database_name == ":memory:" or "mode=memory" in database_name
|
||||
)
|
||||
|
||||
def _get_test_db_name(self):
|
||||
test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
|
||||
if test_database_name == ":memory:":
|
||||
return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
|
||||
return test_database_name
|
||||
|
||||
def _create_test_db(self, verbosity, autoclobber, keepdb=False):
|
||||
test_database_name = self._get_test_db_name()
|
||||
|
||||
if keepdb:
|
||||
return test_database_name
|
||||
if not self.is_in_memory_db(test_database_name):
|
||||
# Erase the old test database
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias %s..."
|
||||
% (self._get_database_display_str(verbosity, test_database_name),)
|
||||
)
|
||||
if os.access(test_database_name, os.F_OK):
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"Type 'yes' if you would like to try deleting the test "
|
||||
"database '%s', or 'no' to cancel: " % test_database_name
|
||||
)
|
||||
if autoclobber or confirm == "yes":
|
||||
try:
|
||||
os.remove(test_database_name)
|
||||
except Exception as e:
|
||||
self.log("Got an error deleting the old test database: %s" % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log("Tests cancelled.")
|
||||
sys.exit(1)
|
||||
return test_database_name
|
||||
|
||||
def get_test_db_clone_settings(self, suffix):
|
||||
orig_settings_dict = self.connection.settings_dict
|
||||
source_database_name = orig_settings_dict["NAME"]
|
||||
if self.is_in_memory_db(source_database_name):
|
||||
return orig_settings_dict
|
||||
else:
|
||||
root, ext = os.path.splitext(orig_settings_dict["NAME"])
|
||||
return {**orig_settings_dict, "NAME": "{}_{}{}".format(root, suffix, ext)}
|
||||
|
||||
def _clone_test_db(self, suffix, verbosity, keepdb=False):
|
||||
source_database_name = self.connection.settings_dict["NAME"]
|
||||
target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
|
||||
# Forking automatically makes a copy of an in-memory database.
|
||||
if not self.is_in_memory_db(source_database_name):
|
||||
# Erase the old test database
|
||||
if os.access(target_database_name, os.F_OK):
|
||||
if keepdb:
|
||||
return
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias %s..."
|
||||
% (
|
||||
self._get_database_display_str(
|
||||
verbosity, target_database_name
|
||||
),
|
||||
)
|
||||
)
|
||||
try:
|
||||
os.remove(target_database_name)
|
||||
except Exception as e:
|
||||
self.log("Got an error deleting the old test database: %s" % e)
|
||||
sys.exit(2)
|
||||
try:
|
||||
shutil.copy(source_database_name, target_database_name)
|
||||
except Exception as e:
|
||||
self.log("Got an error cloning the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
|
||||
def _destroy_test_db(self, test_database_name, verbosity):
|
||||
if test_database_name and not self.is_in_memory_db(test_database_name):
|
||||
# Remove the SQLite database file
|
||||
os.remove(test_database_name)
|
||||
|
||||
def test_db_signature(self):
|
||||
"""
|
||||
Return a tuple that uniquely identifies a test database.
|
||||
|
||||
This takes into account the special cases of ":memory:" and "" for
|
||||
SQLite since the databases will be distinct despite having the same
|
||||
TEST NAME. See https://www.sqlite.org/inmemorydb.html
|
||||
"""
|
||||
test_database_name = self._get_test_db_name()
|
||||
sig = [self.connection.settings_dict["NAME"]]
|
||||
if self.is_in_memory_db(test_database_name):
|
||||
sig.append(self.connection.alias)
|
||||
else:
|
||||
sig.append(test_database_name)
|
||||
return tuple(sig)
|
||||
@@ -0,0 +1,145 @@
|
||||
import operator
|
||||
import platform
|
||||
|
||||
from django.db import transaction
|
||||
from django.db.backends.base.features import BaseDatabaseFeatures
|
||||
from django.db.utils import OperationalError
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from .base import Database
|
||||
|
||||
|
||||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
# SQLite can read from a cursor since SQLite 3.6.5, subject to the caveat
|
||||
# that statements within a connection aren't isolated from each other. See
|
||||
# https://sqlite.org/isolation.html.
|
||||
can_use_chunked_reads = True
|
||||
test_db_allows_multiple_connections = False
|
||||
supports_unspecified_pk = True
|
||||
supports_timezones = False
|
||||
max_query_params = 999
|
||||
supports_mixed_date_datetime_comparisons = False
|
||||
supports_transactions = True
|
||||
atomic_transactions = False
|
||||
can_rollback_ddl = True
|
||||
can_create_inline_fk = False
|
||||
supports_paramstyle_pyformat = False
|
||||
can_clone_databases = True
|
||||
supports_temporal_subtraction = True
|
||||
ignores_table_name_case = True
|
||||
supports_cast_with_precision = False
|
||||
time_cast_precision = 3
|
||||
can_release_savepoints = True
|
||||
# Is "ALTER TABLE ... RENAME COLUMN" supported?
|
||||
can_alter_table_rename_column = Database.sqlite_version_info >= (3, 25, 0)
|
||||
supports_parentheses_in_compound = False
|
||||
# Deferred constraint checks can be emulated on SQLite < 3.20 but not in a
|
||||
# reasonably performant way.
|
||||
supports_pragma_foreign_key_check = Database.sqlite_version_info >= (3, 20, 0)
|
||||
can_defer_constraint_checks = supports_pragma_foreign_key_check
|
||||
supports_functions_in_partial_indexes = Database.sqlite_version_info >= (3, 15, 0)
|
||||
supports_over_clause = Database.sqlite_version_info >= (3, 25, 0)
|
||||
supports_frame_range_fixed_distance = Database.sqlite_version_info >= (3, 28, 0)
|
||||
supports_aggregate_filter_clause = Database.sqlite_version_info >= (3, 30, 1)
|
||||
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
|
||||
order_by_nulls_first = True
|
||||
supports_json_field_contains = False
|
||||
test_collations = {
|
||||
"ci": "nocase",
|
||||
"cs": "binary",
|
||||
"non_default": "nocase",
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def django_test_skips(self):
|
||||
skips = {
|
||||
"SQLite stores values rounded to 15 significant digits.": {
|
||||
"model_fields.test_decimalfield.DecimalFieldTests."
|
||||
"test_fetch_from_db_without_float_rounding",
|
||||
},
|
||||
"SQLite naively remakes the table on field alteration.": {
|
||||
"schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops",
|
||||
"schema.tests.SchemaTests.test_unique_and_reverse_m2m",
|
||||
"schema.tests.SchemaTests."
|
||||
"test_alter_field_default_doesnt_perform_queries",
|
||||
"schema.tests.SchemaTests."
|
||||
"test_rename_column_renames_deferred_sql_references",
|
||||
},
|
||||
"SQLite doesn't have a constraint.": {
|
||||
"model_fields.test_integerfield.PositiveIntegerFieldTests."
|
||||
"test_negative_values",
|
||||
},
|
||||
"SQLite doesn't support negative precision for ROUND().": {
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_null_with_negative_precision",
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_decimal_with_negative_precision",
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_float_with_negative_precision",
|
||||
"db_functions.math.test_round.RoundTests."
|
||||
"test_integer_with_negative_precision",
|
||||
},
|
||||
}
|
||||
if Database.sqlite_version_info < (3, 27):
|
||||
skips.update(
|
||||
{
|
||||
"Nondeterministic failure on SQLite < 3.27.": {
|
||||
"expressions_window.tests.WindowFunctionTests."
|
||||
"test_subquery_row_range_rank",
|
||||
},
|
||||
}
|
||||
)
|
||||
if self.connection.is_in_memory_db():
|
||||
skips.update(
|
||||
{
|
||||
"the sqlite backend's close() method is a no-op when using an "
|
||||
"in-memory database": {
|
||||
"servers.test_liveserverthread.LiveServerThreadTest."
|
||||
"test_closes_connections",
|
||||
"servers.tests.LiveServerTestCloseConnectionTest."
|
||||
"test_closes_connections",
|
||||
},
|
||||
}
|
||||
)
|
||||
return skips
|
||||
|
||||
@cached_property
|
||||
def supports_atomic_references_rename(self):
|
||||
# SQLite 3.28.0 bundled with MacOS 10.15 does not support renaming
|
||||
# references atomically.
|
||||
if platform.mac_ver()[0].startswith(
|
||||
"10.15."
|
||||
) and Database.sqlite_version_info == (3, 28, 0):
|
||||
return False
|
||||
return Database.sqlite_version_info >= (3, 26, 0)
|
||||
|
||||
@cached_property
|
||||
def introspected_field_types(self):
|
||||
return {
|
||||
**super().introspected_field_types,
|
||||
"BigAutoField": "AutoField",
|
||||
"DurationField": "BigIntegerField",
|
||||
"GenericIPAddressField": "CharField",
|
||||
"SmallAutoField": "AutoField",
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def supports_json_field(self):
|
||||
with self.connection.cursor() as cursor:
|
||||
try:
|
||||
with transaction.atomic(self.connection.alias):
|
||||
cursor.execute('SELECT JSON(\'{"a": "b"}\')')
|
||||
except OperationalError:
|
||||
return False
|
||||
return True
|
||||
|
||||
can_introspect_json_field = property(operator.attrgetter("supports_json_field"))
|
||||
has_json_object_function = property(operator.attrgetter("supports_json_field"))
|
||||
|
||||
@cached_property
|
||||
def can_return_columns_from_insert(self):
|
||||
return Database.sqlite_version_info >= (3, 35)
|
||||
|
||||
can_return_rows_from_bulk_insert = property(
|
||||
operator.attrgetter("can_return_columns_from_insert")
|
||||
)
|
||||
@@ -0,0 +1,528 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlparse
|
||||
|
||||
from django.db.backends.base.introspection import BaseDatabaseIntrospection
|
||||
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
|
||||
from django.db.backends.base.introspection import TableInfo
|
||||
from django.db.models import Index
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
FieldInfo = namedtuple(
|
||||
"FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
|
||||
)
|
||||
|
||||
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
|
||||
|
||||
|
||||
def get_field_size(name):
|
||||
"""Extract the size number from a "varchar(11)" type name"""
|
||||
m = field_size_re.search(name)
|
||||
return int(m[1]) if m else None
|
||||
|
||||
|
||||
# This light wrapper "fakes" a dictionary interface, because some SQLite data
|
||||
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
|
||||
# as a simple dictionary lookup.
|
||||
class FlexibleFieldLookupDict:
|
||||
# Maps SQL types to Django Field types. Some of the SQL types have multiple
|
||||
# entries here because SQLite allows for anything and doesn't normalize the
|
||||
# field type; it uses whatever was given.
|
||||
base_data_types_reverse = {
|
||||
"bool": "BooleanField",
|
||||
"boolean": "BooleanField",
|
||||
"smallint": "SmallIntegerField",
|
||||
"smallint unsigned": "PositiveSmallIntegerField",
|
||||
"smallinteger": "SmallIntegerField",
|
||||
"int": "IntegerField",
|
||||
"integer": "IntegerField",
|
||||
"bigint": "BigIntegerField",
|
||||
"integer unsigned": "PositiveIntegerField",
|
||||
"bigint unsigned": "PositiveBigIntegerField",
|
||||
"decimal": "DecimalField",
|
||||
"real": "FloatField",
|
||||
"text": "TextField",
|
||||
"char": "CharField",
|
||||
"varchar": "CharField",
|
||||
"blob": "BinaryField",
|
||||
"date": "DateField",
|
||||
"datetime": "DateTimeField",
|
||||
"time": "TimeField",
|
||||
}
|
||||
|
||||
def __getitem__(self, key):
|
||||
key = key.lower().split("(", 1)[0].strip()
|
||||
return self.base_data_types_reverse[key]
|
||||
|
||||
|
||||
class DatabaseIntrospection(BaseDatabaseIntrospection):
|
||||
data_types_reverse = FlexibleFieldLookupDict()
|
||||
|
||||
def get_field_type(self, data_type, description):
|
||||
field_type = super().get_field_type(data_type, description)
|
||||
if description.pk and field_type in {
|
||||
"BigIntegerField",
|
||||
"IntegerField",
|
||||
"SmallIntegerField",
|
||||
}:
|
||||
# No support for BigAutoField or SmallAutoField as SQLite treats
|
||||
# all integer primary keys as signed 64-bit integers.
|
||||
return "AutoField"
|
||||
if description.has_json_constraint:
|
||||
return "JSONField"
|
||||
return field_type
|
||||
|
||||
def get_table_list(self, cursor):
|
||||
"""Return a list of table and view names in the current database."""
|
||||
# Skip the sqlite_sequence system table used for autoincrement key
|
||||
# generation.
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT name, type FROM sqlite_master
|
||||
WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
|
||||
ORDER BY name"""
|
||||
)
|
||||
return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
|
||||
|
||||
def get_table_description(self, cursor, table_name):
|
||||
"""
|
||||
Return a description of the table with the DB-API cursor.description
|
||||
interface.
|
||||
"""
|
||||
cursor.execute(
|
||||
"PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
table_info = cursor.fetchall()
|
||||
collations = self._get_column_collations(cursor, table_name)
|
||||
json_columns = set()
|
||||
if self.connection.features.can_introspect_json_field:
|
||||
for line in table_info:
|
||||
column = line[1]
|
||||
json_constraint_sql = '%%json_valid("%s")%%' % column
|
||||
has_json_constraint = cursor.execute(
|
||||
"""
|
||||
SELECT sql
|
||||
FROM sqlite_master
|
||||
WHERE
|
||||
type = 'table' AND
|
||||
name = %s AND
|
||||
sql LIKE %s
|
||||
""",
|
||||
[table_name, json_constraint_sql],
|
||||
).fetchone()
|
||||
if has_json_constraint:
|
||||
json_columns.add(column)
|
||||
return [
|
||||
FieldInfo(
|
||||
name,
|
||||
data_type,
|
||||
None,
|
||||
get_field_size(data_type),
|
||||
None,
|
||||
None,
|
||||
not notnull,
|
||||
default,
|
||||
collations.get(name),
|
||||
pk == 1,
|
||||
name in json_columns,
|
||||
)
|
||||
for cid, name, data_type, notnull, default, pk in table_info
|
||||
]
|
||||
|
||||
def get_sequences(self, cursor, table_name, table_fields=()):
|
||||
pk_col = self.get_primary_key_column(cursor, table_name)
|
||||
return [{"table": table_name, "column": pk_col}]
|
||||
|
||||
def get_relations(self, cursor, table_name):
|
||||
"""
|
||||
Return a dictionary of {field_name: (field_name_other_table, other_table)}
|
||||
representing all relationships to the given table.
|
||||
"""
|
||||
# Dictionary of relations to return
|
||||
relations = {}
|
||||
|
||||
# Schema for this table
|
||||
cursor.execute(
|
||||
"SELECT sql, type FROM sqlite_master "
|
||||
"WHERE tbl_name = %s AND type IN ('table', 'view')",
|
||||
[table_name],
|
||||
)
|
||||
create_sql, table_type = cursor.fetchone()
|
||||
if table_type == "view":
|
||||
# It might be a view, then no results will be returned
|
||||
return relations
|
||||
results = create_sql[create_sql.index("(") + 1 : create_sql.rindex(")")]
|
||||
|
||||
# Walk through and look for references to other tables. SQLite doesn't
|
||||
# really have enforced references, but since it echoes out the SQL used
|
||||
# to create the table we can look for REFERENCES statements used there.
|
||||
for field_desc in results.split(","):
|
||||
field_desc = field_desc.strip()
|
||||
if field_desc.startswith("UNIQUE"):
|
||||
continue
|
||||
|
||||
m = re.search(r'references (\S*) ?\(["|]?(.*)["|]?\)', field_desc, re.I)
|
||||
if not m:
|
||||
continue
|
||||
table, column = [s.strip('"') for s in m.groups()]
|
||||
|
||||
if field_desc.startswith("FOREIGN KEY"):
|
||||
# Find name of the target FK field
|
||||
m = re.match(r"FOREIGN KEY\s*\(([^\)]*)\).*", field_desc, re.I)
|
||||
field_name = m[1].strip('"')
|
||||
else:
|
||||
field_name = field_desc.split()[0].strip('"')
|
||||
|
||||
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
|
||||
result = cursor.fetchall()[0]
|
||||
other_table_results = result[0].strip()
|
||||
li, ri = other_table_results.index("("), other_table_results.rindex(")")
|
||||
other_table_results = other_table_results[li + 1 : ri]
|
||||
|
||||
for other_desc in other_table_results.split(","):
|
||||
other_desc = other_desc.strip()
|
||||
if other_desc.startswith("UNIQUE"):
|
||||
continue
|
||||
|
||||
other_name = other_desc.split(" ", 1)[0].strip('"')
|
||||
if other_name == column:
|
||||
relations[field_name] = (other_name, table)
|
||||
break
|
||||
|
||||
return relations
|
||||
|
||||
def get_key_columns(self, cursor, table_name):
|
||||
"""
|
||||
Return a list of (column_name, referenced_table_name, referenced_column_name)
|
||||
for all key columns in given table.
|
||||
"""
|
||||
key_columns = []
|
||||
|
||||
# Schema for this table
|
||||
cursor.execute(
|
||||
"SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s",
|
||||
[table_name, "table"],
|
||||
)
|
||||
results = cursor.fetchone()[0].strip()
|
||||
results = results[results.index("(") + 1 : results.rindex(")")]
|
||||
|
||||
# Walk through and look for references to other tables. SQLite doesn't
|
||||
# really have enforced references, but since it echoes out the SQL used
|
||||
# to create the table we can look for REFERENCES statements used there.
|
||||
for field_index, field_desc in enumerate(results.split(",")):
|
||||
field_desc = field_desc.strip()
|
||||
if field_desc.startswith("UNIQUE"):
|
||||
continue
|
||||
|
||||
m = re.search(r'"(.*)".*references (.*) \(["|](.*)["|]\)', field_desc, re.I)
|
||||
if not m:
|
||||
continue
|
||||
|
||||
# This will append
|
||||
# (column_name, referenced_table_name, referenced_column_name) to
|
||||
# key_columns.
|
||||
key_columns.append(tuple(s.strip('"') for s in m.groups()))
|
||||
|
||||
return key_columns
|
||||
|
||||
def get_primary_key_column(self, cursor, table_name):
|
||||
"""Return the column name of the primary key for the given table."""
|
||||
# Don't use PRAGMA because that causes issues with some transactions
|
||||
cursor.execute(
|
||||
"SELECT sql, type FROM sqlite_master "
|
||||
"WHERE tbl_name = %s AND type IN ('table', 'view')",
|
||||
[table_name],
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise ValueError("Table %s does not exist" % table_name)
|
||||
create_sql, table_type = row
|
||||
if table_type == "view":
|
||||
# Views don't have a primary key.
|
||||
return None
|
||||
fields_sql = create_sql[create_sql.index("(") + 1 : create_sql.rindex(")")]
|
||||
for field_desc in fields_sql.split(","):
|
||||
field_desc = field_desc.strip()
|
||||
m = re.match(
|
||||
r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*', field_desc
|
||||
)
|
||||
if m:
|
||||
return m[1] if m[1] else m[2]
|
||||
return None
|
||||
|
||||
def _get_foreign_key_constraints(self, cursor, table_name):
|
||||
constraints = {}
|
||||
cursor.execute(
|
||||
"PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
# Remaining on_update/on_delete/match values are of no interest.
|
||||
id_, _, table, from_, to = row[:5]
|
||||
constraints["fk_%d" % id_] = {
|
||||
"columns": [from_],
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"foreign_key": (table, to),
|
||||
"check": False,
|
||||
"index": False,
|
||||
}
|
||||
return constraints
|
||||
|
||||
def _parse_column_or_constraint_definition(self, tokens, columns):
|
||||
token = None
|
||||
is_constraint_definition = None
|
||||
field_name = None
|
||||
constraint_name = None
|
||||
unique = False
|
||||
unique_columns = []
|
||||
check = False
|
||||
check_columns = []
|
||||
braces_deep = 0
|
||||
for token in tokens:
|
||||
if token.match(sqlparse.tokens.Punctuation, "("):
|
||||
braces_deep += 1
|
||||
elif token.match(sqlparse.tokens.Punctuation, ")"):
|
||||
braces_deep -= 1
|
||||
if braces_deep < 0:
|
||||
# End of columns and constraints for table definition.
|
||||
break
|
||||
elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
|
||||
# End of current column or constraint definition.
|
||||
break
|
||||
# Detect column or constraint definition by first token.
|
||||
if is_constraint_definition is None:
|
||||
is_constraint_definition = token.match(
|
||||
sqlparse.tokens.Keyword, "CONSTRAINT"
|
||||
)
|
||||
if is_constraint_definition:
|
||||
continue
|
||||
if is_constraint_definition:
|
||||
# Detect constraint name by second token.
|
||||
if constraint_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
constraint_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
constraint_name = token.value[1:-1]
|
||||
# Start constraint columns parsing after UNIQUE keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
|
||||
unique = True
|
||||
unique_braces_deep = braces_deep
|
||||
elif unique:
|
||||
if unique_braces_deep == braces_deep:
|
||||
if unique_columns:
|
||||
# Stop constraint parsing.
|
||||
unique = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
unique_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
unique_columns.append(token.value[1:-1])
|
||||
else:
|
||||
# Detect field name by first token.
|
||||
if field_name is None:
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
field_name = token.value
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
field_name = token.value[1:-1]
|
||||
if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
|
||||
unique_columns = [field_name]
|
||||
# Start constraint columns parsing after CHECK keyword.
|
||||
if token.match(sqlparse.tokens.Keyword, "CHECK"):
|
||||
check = True
|
||||
check_braces_deep = braces_deep
|
||||
elif check:
|
||||
if check_braces_deep == braces_deep:
|
||||
if check_columns:
|
||||
# Stop constraint parsing.
|
||||
check = False
|
||||
continue
|
||||
if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
|
||||
if token.value in columns:
|
||||
check_columns.append(token.value)
|
||||
elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
|
||||
if token.value[1:-1] in columns:
|
||||
check_columns.append(token.value[1:-1])
|
||||
unique_constraint = (
|
||||
{
|
||||
"unique": True,
|
||||
"columns": unique_columns,
|
||||
"primary_key": False,
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": False,
|
||||
}
|
||||
if unique_columns
|
||||
else None
|
||||
)
|
||||
check_constraint = (
|
||||
{
|
||||
"check": True,
|
||||
"columns": check_columns,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"foreign_key": None,
|
||||
"index": False,
|
||||
}
|
||||
if check_columns
|
||||
else None
|
||||
)
|
||||
return constraint_name, unique_constraint, check_constraint, token
|
||||
|
||||
def _parse_table_constraints(self, sql, columns):
|
||||
# Check constraint parsing is based of SQLite syntax diagram.
|
||||
# https://www.sqlite.org/syntaxdiagrams.html#table-constraint
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
constraints = {}
|
||||
unnamed_constrains_index = 0
|
||||
tokens = (token for token in statement.flatten() if not token.is_whitespace)
|
||||
# Go to columns and constraint definition
|
||||
for token in tokens:
|
||||
if token.match(sqlparse.tokens.Punctuation, "("):
|
||||
break
|
||||
# Parse columns and constraint definition
|
||||
while True:
|
||||
(
|
||||
constraint_name,
|
||||
unique,
|
||||
check,
|
||||
end_token,
|
||||
) = self._parse_column_or_constraint_definition(tokens, columns)
|
||||
if unique:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = unique
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints[
|
||||
"__unnamed_constraint_%s__" % unnamed_constrains_index
|
||||
] = unique
|
||||
if check:
|
||||
if constraint_name:
|
||||
constraints[constraint_name] = check
|
||||
else:
|
||||
unnamed_constrains_index += 1
|
||||
constraints[
|
||||
"__unnamed_constraint_%s__" % unnamed_constrains_index
|
||||
] = check
|
||||
if end_token.match(sqlparse.tokens.Punctuation, ")"):
|
||||
break
|
||||
return constraints
|
||||
|
||||
def get_constraints(self, cursor, table_name):
|
||||
"""
|
||||
Retrieve any constraints or keys (unique, pk, fk, check, index) across
|
||||
one or more columns.
|
||||
"""
|
||||
constraints = {}
|
||||
# Find inline check constraints.
|
||||
try:
|
||||
table_schema = cursor.execute(
|
||||
"SELECT sql FROM sqlite_master WHERE type='table' and name=%s"
|
||||
% (self.connection.ops.quote_name(table_name),)
|
||||
).fetchone()[0]
|
||||
except TypeError:
|
||||
# table_name is a view.
|
||||
pass
|
||||
else:
|
||||
columns = {
|
||||
info.name for info in self.get_table_description(cursor, table_name)
|
||||
}
|
||||
constraints.update(self._parse_table_constraints(table_schema, columns))
|
||||
|
||||
# Get the index info
|
||||
cursor.execute(
|
||||
"PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
# SQLite 3.8.9+ has 5 columns, however older versions only give 3
|
||||
# columns. Discard last 2 columns if there.
|
||||
number, index, unique = row[:3]
|
||||
cursor.execute(
|
||||
"SELECT sql FROM sqlite_master "
|
||||
"WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
|
||||
)
|
||||
# There's at most one row.
|
||||
(sql,) = cursor.fetchone() or (None,)
|
||||
# Inline constraints are already detected in
|
||||
# _parse_table_constraints(). The reasons to avoid fetching inline
|
||||
# constraints from `PRAGMA index_list` are:
|
||||
# - Inline constraints can have a different name and information
|
||||
# than what `PRAGMA index_list` gives.
|
||||
# - Not all inline constraints may appear in `PRAGMA index_list`.
|
||||
if not sql:
|
||||
# An inline constraint
|
||||
continue
|
||||
# Get the index info for that index
|
||||
cursor.execute(
|
||||
"PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
|
||||
)
|
||||
for index_rank, column_rank, column in cursor.fetchall():
|
||||
if index not in constraints:
|
||||
constraints[index] = {
|
||||
"columns": [],
|
||||
"primary_key": False,
|
||||
"unique": bool(unique),
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": True,
|
||||
}
|
||||
constraints[index]["columns"].append(column)
|
||||
# Add type and column orders for indexes
|
||||
if constraints[index]["index"]:
|
||||
# SQLite doesn't support any index type other than b-tree
|
||||
constraints[index]["type"] = Index.suffix
|
||||
orders = self._get_index_columns_orders(sql)
|
||||
if orders is not None:
|
||||
constraints[index]["orders"] = orders
|
||||
# Get the PK
|
||||
pk_column = self.get_primary_key_column(cursor, table_name)
|
||||
if pk_column:
|
||||
# SQLite doesn't actually give a name to the PK constraint,
|
||||
# so we invent one. This is fine, as the SQLite backend never
|
||||
# deletes PK constraints by name, as you can't delete constraints
|
||||
# in SQLite; we remake the table with a new PK instead.
|
||||
constraints["__primary__"] = {
|
||||
"columns": [pk_column],
|
||||
"primary_key": True,
|
||||
"unique": False, # It's not actually a unique constraint.
|
||||
"foreign_key": None,
|
||||
"check": False,
|
||||
"index": False,
|
||||
}
|
||||
constraints.update(self._get_foreign_key_constraints(cursor, table_name))
|
||||
return constraints
|
||||
|
||||
def _get_index_columns_orders(self, sql):
|
||||
tokens = sqlparse.parse(sql)[0]
|
||||
for token in tokens:
|
||||
if isinstance(token, sqlparse.sql.Parenthesis):
|
||||
columns = str(token).strip("()").split(", ")
|
||||
return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
|
||||
return None
|
||||
|
||||
def _get_column_collations(self, cursor, table_name):
|
||||
row = cursor.execute(
|
||||
"""
|
||||
SELECT sql
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table' AND name = %s
|
||||
""",
|
||||
[table_name],
|
||||
).fetchone()
|
||||
if not row:
|
||||
return {}
|
||||
|
||||
sql = row[0]
|
||||
columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
|
||||
collations = {}
|
||||
for column in columns:
|
||||
tokens = column[1:].split()
|
||||
column_name = tokens[0].strip('"')
|
||||
for index, token in enumerate(tokens):
|
||||
if token == "COLLATE":
|
||||
collation = tokens[index + 1]
|
||||
break
|
||||
else:
|
||||
collation = None
|
||||
collations[column_name] = collation
|
||||
return collations
|
||||
@@ -0,0 +1,416 @@
|
||||
import datetime
|
||||
import decimal
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import DatabaseError, NotSupportedError, models
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.models.expressions import Col
|
||||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
cast_char_field_without_max_length = "text"
|
||||
cast_data_types = {
|
||||
"DateField": "TEXT",
|
||||
"DateTimeField": "TEXT",
|
||||
}
|
||||
explain_prefix = "EXPLAIN QUERY PLAN"
|
||||
# List of datatypes to that cannot be extracted with JSON_EXTRACT() on
|
||||
# SQLite. Use JSON_TYPE() instead.
|
||||
jsonfield_datatype_values = frozenset(["null", "false", "true"])
|
||||
|
||||
def bulk_batch_size(self, fields, objs):
|
||||
"""
|
||||
SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
|
||||
999 variables per query.
|
||||
|
||||
If there's only a single field to insert, the limit is 500
|
||||
(SQLITE_MAX_COMPOUND_SELECT).
|
||||
"""
|
||||
if len(fields) == 1:
|
||||
return 500
|
||||
elif len(fields) > 1:
|
||||
return self.connection.features.max_query_params // len(fields)
|
||||
else:
|
||||
return len(objs)
|
||||
|
||||
def check_expression_support(self, expression):
|
||||
bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
|
||||
bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
|
||||
if isinstance(expression, bad_aggregates):
|
||||
for expr in expression.get_source_expressions():
|
||||
try:
|
||||
output_field = expr.output_field
|
||||
except (AttributeError, FieldError):
|
||||
# Not every subexpression has an output_field which is fine
|
||||
# to ignore.
|
||||
pass
|
||||
else:
|
||||
if isinstance(output_field, bad_fields):
|
||||
raise NotSupportedError(
|
||||
"You cannot use Sum, Avg, StdDev, and Variance "
|
||||
"aggregations on date/time fields in sqlite3 "
|
||||
"since date/time is saved as text."
|
||||
)
|
||||
if (
|
||||
isinstance(expression, models.Aggregate)
|
||||
and expression.distinct
|
||||
and len(expression.source_expressions) > 1
|
||||
):
|
||||
raise NotSupportedError(
|
||||
"SQLite doesn't support DISTINCT on aggregate functions "
|
||||
"accepting multiple arguments."
|
||||
)
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
"""
|
||||
Support EXTRACT with a user-defined function django_date_extract()
|
||||
that's registered in connect(). Use single quotes because this is a
|
||||
string and could otherwise cause a collision with a field name.
|
||||
"""
|
||||
return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the list of returned data.
|
||||
"""
|
||||
return cursor.fetchall()
|
||||
|
||||
def format_for_duration_arithmetic(self, sql):
|
||||
"""Do nothing since formatting is handled in the custom function."""
|
||||
return sql
|
||||
|
||||
def date_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_date_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, field_name, tzname=None):
|
||||
return "django_time_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def _convert_tznames_to_sql(self, tzname):
|
||||
if tzname and settings.USE_TZ:
|
||||
return "'%s'" % tzname, "'%s'" % self.connection.timezone_name
|
||||
return "NULL", "NULL"
|
||||
|
||||
def datetime_cast_date_sql(self, field_name, tzname):
|
||||
return "django_datetime_cast_date(%s, %s, %s)" % (
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_cast_time_sql(self, field_name, tzname):
|
||||
return "django_datetime_cast_time(%s, %s, %s)" % (
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_extract('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, field_name, tzname):
|
||||
return "django_datetime_trunc('%s', %s, %s, %s)" % (
|
||||
lookup_type.lower(),
|
||||
field_name,
|
||||
*self._convert_tznames_to_sql(tzname),
|
||||
)
|
||||
|
||||
def time_extract_sql(self, lookup_type, field_name):
|
||||
return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def _quote_params_for_last_executed_query(self, params):
|
||||
"""
|
||||
Only for last_executed_query! Don't use this to execute SQL queries!
|
||||
"""
|
||||
# This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
|
||||
# number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
|
||||
# number of return values, default = 2000). Since Python's sqlite3
|
||||
# module doesn't expose the get_limit() C API, assume the default
|
||||
# limits are in effect and split the work in batches if needed.
|
||||
BATCH_SIZE = 999
|
||||
if len(params) > BATCH_SIZE:
|
||||
results = ()
|
||||
for index in range(0, len(params), BATCH_SIZE):
|
||||
chunk = params[index : index + BATCH_SIZE]
|
||||
results += self._quote_params_for_last_executed_query(chunk)
|
||||
return results
|
||||
|
||||
sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
|
||||
# Bypass Django's wrappers and use the underlying sqlite3 connection
|
||||
# to avoid logging this query - it would trigger infinite recursion.
|
||||
cursor = self.connection.connection.cursor()
|
||||
# Native sqlite3 cursors cannot be used as context managers.
|
||||
try:
|
||||
return cursor.execute(sql, params).fetchone()
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def last_executed_query(self, cursor, sql, params):
|
||||
# Python substitutes parameters in Modules/_sqlite/cursor.c with:
|
||||
# pysqlite_statement_bind_parameters(
|
||||
# self->statement, parameters, allow_8bit_chars
|
||||
# );
|
||||
# Unfortunately there is no way to reach self->statement from Python,
|
||||
# so we quote and substitute parameters manually.
|
||||
if params:
|
||||
if isinstance(params, (list, tuple)):
|
||||
params = self._quote_params_for_last_executed_query(params)
|
||||
else:
|
||||
values = tuple(params.values())
|
||||
values = self._quote_params_for_last_executed_query(values)
|
||||
params = dict(zip(params, values))
|
||||
return sql % params
|
||||
# For consistency with SQLiteCursorWrapper.execute(), just return sql
|
||||
# when there are no parameters. See #13648 and #17158.
|
||||
else:
|
||||
return sql
|
||||
|
||||
def quote_name(self, name):
|
||||
if name.startswith('"') and name.endswith('"'):
|
||||
return name # Quoting once is enough.
|
||||
return '"%s"' % name
|
||||
|
||||
def no_limit_value(self):
|
||||
return -1
|
||||
|
||||
def __references_graph(self, table_name):
|
||||
query = """
|
||||
WITH tables AS (
|
||||
SELECT %s name
|
||||
UNION
|
||||
SELECT sqlite_master.name
|
||||
FROM sqlite_master
|
||||
JOIN tables ON (sql REGEXP %s || tables.name || %s)
|
||||
) SELECT name FROM tables;
|
||||
"""
|
||||
params = (
|
||||
table_name,
|
||||
r'(?i)\s+references\s+("|\')?',
|
||||
r'("|\')?\s*\(',
|
||||
)
|
||||
with self.connection.cursor() as cursor:
|
||||
results = cursor.execute(query, params)
|
||||
return [row[0] for row in results.fetchall()]
|
||||
|
||||
@cached_property
|
||||
def _references_graph(self):
|
||||
# 512 is large enough to fit the ~330 tables (as of this writing) in
|
||||
# Django's test suite.
|
||||
return lru_cache(maxsize=512)(self.__references_graph)
|
||||
|
||||
def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
|
||||
if tables and allow_cascade:
|
||||
# Simulate TRUNCATE CASCADE by recursively collecting the tables
|
||||
# referencing the tables to be flushed.
|
||||
tables = set(
|
||||
chain.from_iterable(self._references_graph(table) for table in tables)
|
||||
)
|
||||
sql = [
|
||||
"%s %s %s;"
|
||||
% (
|
||||
style.SQL_KEYWORD("DELETE"),
|
||||
style.SQL_KEYWORD("FROM"),
|
||||
style.SQL_FIELD(self.quote_name(table)),
|
||||
)
|
||||
for table in tables
|
||||
]
|
||||
if reset_sequences:
|
||||
sequences = [{"table": table} for table in tables]
|
||||
sql.extend(self.sequence_reset_by_name_sql(style, sequences))
|
||||
return sql
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
if not sequences:
|
||||
return []
|
||||
return [
|
||||
"%s %s %s %s = 0 %s %s %s (%s);"
|
||||
% (
|
||||
style.SQL_KEYWORD("UPDATE"),
|
||||
style.SQL_TABLE(self.quote_name("sqlite_sequence")),
|
||||
style.SQL_KEYWORD("SET"),
|
||||
style.SQL_FIELD(self.quote_name("seq")),
|
||||
style.SQL_KEYWORD("WHERE"),
|
||||
style.SQL_FIELD(self.quote_name("name")),
|
||||
style.SQL_KEYWORD("IN"),
|
||||
", ".join(
|
||||
["'%s'" % sequence_info["table"] for sequence_info in sequences]
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError(
|
||||
"SQLite backend does not support timezone-aware datetimes when "
|
||||
"USE_TZ is False."
|
||||
)
|
||||
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# SQLite doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("SQLite backend does not support timezone-aware times.")
|
||||
|
||||
return str(value)
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type == "DateTimeField":
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == "DateField":
|
||||
converters.append(self.convert_datefield_value)
|
||||
elif internal_type == "TimeField":
|
||||
converters.append(self.convert_timefield_value)
|
||||
elif internal_type == "DecimalField":
|
||||
converters.append(self.get_decimalfield_converter(expression))
|
||||
elif internal_type == "UUIDField":
|
||||
converters.append(self.convert_uuidfield_value)
|
||||
elif internal_type == "BooleanField":
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
return converters
|
||||
|
||||
def convert_datetimefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
if not isinstance(value, datetime.datetime):
|
||||
value = parse_datetime(value)
|
||||
if settings.USE_TZ and not timezone.is_aware(value):
|
||||
value = timezone.make_aware(value, self.connection.timezone)
|
||||
return value
|
||||
|
||||
def convert_datefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
if not isinstance(value, datetime.date):
|
||||
value = parse_date(value)
|
||||
return value
|
||||
|
||||
def convert_timefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
if not isinstance(value, datetime.time):
|
||||
value = parse_time(value)
|
||||
return value
|
||||
|
||||
def get_decimalfield_converter(self, expression):
|
||||
# SQLite stores only 15 significant digits. Digits coming from
|
||||
# float inaccuracy must be removed.
|
||||
create_decimal = decimal.Context(prec=15).create_decimal_from_float
|
||||
if isinstance(expression, Col):
|
||||
quantize_value = decimal.Decimal(1).scaleb(
|
||||
-expression.output_field.decimal_places
|
||||
)
|
||||
|
||||
def converter(value, expression, connection):
|
||||
if value is not None:
|
||||
return create_decimal(value).quantize(
|
||||
quantize_value, context=expression.output_field.context
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def converter(value, expression, connection):
|
||||
if value is not None:
|
||||
return create_decimal(value)
|
||||
|
||||
return converter
|
||||
|
||||
def convert_uuidfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
return bool(value) if value in (1, 0) else value
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
return " UNION ALL ".join(
|
||||
"SELECT %s" % ", ".join(row) for row in placeholder_rows
|
||||
)
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
# SQLite doesn't have a ^ operator, so use the user-defined POWER
|
||||
# function that's registered in connect().
|
||||
if connector == "^":
|
||||
return "POWER(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "#":
|
||||
return "BITXOR(%s)" % ",".join(sub_expressions)
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
|
||||
def combine_duration_expression(self, connector, sub_expressions):
|
||||
if connector not in ["+", "-", "*", "/"]:
|
||||
raise DatabaseError("Invalid connector for timedelta: %s." % connector)
|
||||
fn_params = ["'%s'" % connector] + sub_expressions
|
||||
if len(fn_params) > 3:
|
||||
raise ValueError("Too many params for timedelta operations.")
|
||||
return "django_format_dtdelta(%s)" % ", ".join(fn_params)
|
||||
|
||||
def integer_field_range(self, internal_type):
|
||||
# SQLite doesn't enforce any integer constraints
|
||||
return (None, None)
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
params = (*lhs_params, *rhs_params)
|
||||
if internal_type == "TimeField":
|
||||
return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
|
||||
return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
|
||||
|
||||
def insert_statement(self, ignore_conflicts=False):
|
||||
return (
|
||||
"INSERT OR IGNORE INTO"
|
||||
if ignore_conflicts
|
||||
else super().insert_statement(ignore_conflicts)
|
||||
)
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
|
||||
if not fields:
|
||||
return "", ()
|
||||
columns = [
|
||||
"%s.%s"
|
||||
% (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
)
|
||||
for field in fields
|
||||
]
|
||||
return "RETURNING %s" % ", ".join(columns), ()
|
||||
@@ -0,0 +1,531 @@
|
||||
import copy
|
||||
from decimal import Decimal
|
||||
|
||||
from django.apps.registry import Apps
|
||||
from django.db import NotSupportedError
|
||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
|
||||
from django.db.backends.ddl_references import Statement
|
||||
from django.db.backends.utils import strip_quotes
|
||||
from django.db.models import UniqueConstraint
|
||||
from django.db.transaction import atomic
|
||||
|
||||
|
||||
class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
|
||||
|
||||
sql_delete_table = "DROP TABLE %(table)s"
|
||||
sql_create_fk = None
|
||||
sql_create_inline_fk = (
|
||||
"REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
|
||||
)
|
||||
sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
|
||||
sql_delete_unique = "DROP INDEX %(name)s"
|
||||
|
||||
def __enter__(self):
|
||||
# Some SQLite schema alterations need foreign key constraints to be
|
||||
# disabled. Enforce it here for the duration of the schema edition.
|
||||
if not self.connection.disable_constraint_checking():
|
||||
raise NotSupportedError(
|
||||
"SQLite schema editor cannot be used while foreign key "
|
||||
"constraint checks are enabled. Make sure to disable them "
|
||||
"before entering a transaction.atomic() context because "
|
||||
"SQLite does not support disabling them in the middle of "
|
||||
"a multi-statement transaction."
|
||||
)
|
||||
return super().__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.connection.check_constraints()
|
||||
super().__exit__(exc_type, exc_value, traceback)
|
||||
self.connection.enable_constraint_checking()
|
||||
|
||||
def quote_value(self, value):
|
||||
# The backend "mostly works" without this function and there are use
|
||||
# cases for compiling Python without the sqlite3 libraries (e.g.
|
||||
# security hardening).
|
||||
try:
|
||||
import sqlite3
|
||||
|
||||
value = sqlite3.adapt(value)
|
||||
except ImportError:
|
||||
pass
|
||||
except sqlite3.ProgrammingError:
|
||||
pass
|
||||
# Manual emulation of SQLite parameter quoting
|
||||
if isinstance(value, bool):
|
||||
return str(int(value))
|
||||
elif isinstance(value, (Decimal, float, int)):
|
||||
return str(value)
|
||||
elif isinstance(value, str):
|
||||
return "'%s'" % value.replace("'", "''")
|
||||
elif value is None:
|
||||
return "NULL"
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
# Bytes are only allowed for BLOB fields, encoded as string
|
||||
# literals containing hexadecimal data and preceded by a single "X"
|
||||
# character.
|
||||
return "X'%s'" % value.hex()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot quote parameter value %r of type %s" % (value, type(value))
|
||||
)
|
||||
|
||||
def _is_referenced_by_fk_constraint(
|
||||
self, table_name, column_name=None, ignore_self=False
|
||||
):
|
||||
"""
|
||||
Return whether or not the provided table name is referenced by another
|
||||
one. If `column_name` is specified, only references pointing to that
|
||||
column are considered. If `ignore_self` is True, self-referential
|
||||
constraints are ignored.
|
||||
"""
|
||||
with self.connection.cursor() as cursor:
|
||||
for other_table in self.connection.introspection.get_table_list(cursor):
|
||||
if ignore_self and other_table.name == table_name:
|
||||
continue
|
||||
constraints = (
|
||||
self.connection.introspection._get_foreign_key_constraints(
|
||||
cursor, other_table.name
|
||||
)
|
||||
)
|
||||
for constraint in constraints.values():
|
||||
constraint_table, constraint_column = constraint["foreign_key"]
|
||||
if constraint_table == table_name and (
|
||||
column_name is None or constraint_column == column_name
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def alter_db_table(
|
||||
self, model, old_db_table, new_db_table, disable_constraints=True
|
||||
):
|
||||
if (
|
||||
not self.connection.features.supports_atomic_references_rename
|
||||
and disable_constraints
|
||||
and self._is_referenced_by_fk_constraint(old_db_table)
|
||||
):
|
||||
if self.connection.in_atomic_block:
|
||||
raise NotSupportedError(
|
||||
(
|
||||
"Renaming the %r table while in a transaction is not "
|
||||
"supported on SQLite < 3.26 because it would break referential "
|
||||
"integrity. Try adding `atomic = False` to the Migration class."
|
||||
)
|
||||
% old_db_table
|
||||
)
|
||||
self.connection.enable_constraint_checking()
|
||||
super().alter_db_table(model, old_db_table, new_db_table)
|
||||
self.connection.disable_constraint_checking()
|
||||
else:
|
||||
super().alter_db_table(model, old_db_table, new_db_table)
|
||||
|
||||
def alter_field(self, model, old_field, new_field, strict=False):
|
||||
if not self._field_should_be_altered(old_field, new_field):
|
||||
return
|
||||
old_field_name = old_field.name
|
||||
table_name = model._meta.db_table
|
||||
_, old_column_name = old_field.get_attname_column()
|
||||
if (
|
||||
new_field.name != old_field_name
|
||||
and not self.connection.features.supports_atomic_references_rename
|
||||
and self._is_referenced_by_fk_constraint(
|
||||
table_name, old_column_name, ignore_self=True
|
||||
)
|
||||
):
|
||||
if self.connection.in_atomic_block:
|
||||
raise NotSupportedError(
|
||||
(
|
||||
"Renaming the %r.%r column while in a transaction is not "
|
||||
"supported on SQLite < 3.26 because it would break referential "
|
||||
"integrity. Try adding `atomic = False` to the Migration class."
|
||||
)
|
||||
% (model._meta.db_table, old_field_name)
|
||||
)
|
||||
with atomic(self.connection.alias):
|
||||
super().alter_field(model, old_field, new_field, strict=strict)
|
||||
# Follow SQLite's documented procedure for performing changes
|
||||
# that don't affect the on-disk content.
|
||||
# https://sqlite.org/lang_altertable.html#otheralter
|
||||
with self.connection.cursor() as cursor:
|
||||
schema_version = cursor.execute("PRAGMA schema_version").fetchone()[
|
||||
0
|
||||
]
|
||||
cursor.execute("PRAGMA writable_schema = 1")
|
||||
references_template = ' REFERENCES "%s" ("%%s") ' % table_name
|
||||
new_column_name = new_field.get_attname_column()[1]
|
||||
search = references_template % old_column_name
|
||||
replacement = references_template % new_column_name
|
||||
cursor.execute(
|
||||
"UPDATE sqlite_master SET sql = replace(sql, %s, %s)",
|
||||
(search, replacement),
|
||||
)
|
||||
cursor.execute("PRAGMA schema_version = %d" % (schema_version + 1))
|
||||
cursor.execute("PRAGMA writable_schema = 0")
|
||||
# The integrity check will raise an exception and rollback
|
||||
# the transaction if the sqlite_master updates corrupt the
|
||||
# database.
|
||||
cursor.execute("PRAGMA integrity_check")
|
||||
# Perform a VACUUM to refresh the database representation from
|
||||
# the sqlite_master table.
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("VACUUM")
|
||||
else:
|
||||
super().alter_field(model, old_field, new_field, strict=strict)
|
||||
|
||||
def _remake_table(
|
||||
self, model, create_field=None, delete_field=None, alter_field=None
|
||||
):
|
||||
"""
|
||||
Shortcut to transform a model from old_model into new_model
|
||||
|
||||
This follows the correct procedure to perform non-rename or column
|
||||
addition operations based on SQLite's documentation
|
||||
|
||||
https://www.sqlite.org/lang_altertable.html#caution
|
||||
|
||||
The essential steps are:
|
||||
1. Create a table with the updated definition called "new__app_model"
|
||||
2. Copy the data from the existing "app_model" table to the new table
|
||||
3. Drop the "app_model" table
|
||||
4. Rename the "new__app_model" table to "app_model"
|
||||
5. Restore any index of the previous "app_model" table.
|
||||
"""
|
||||
# Self-referential fields must be recreated rather than copied from
|
||||
# the old model to ensure their remote_field.field_name doesn't refer
|
||||
# to an altered field.
|
||||
def is_self_referential(f):
|
||||
return f.is_relation and f.remote_field.model is model
|
||||
|
||||
# Work out the new fields dict / mapping
|
||||
body = {
|
||||
f.name: f.clone() if is_self_referential(f) else f
|
||||
for f in model._meta.local_concrete_fields
|
||||
}
|
||||
# Since mapping might mix column names and default values,
|
||||
# its values must be already quoted.
|
||||
mapping = {
|
||||
f.column: self.quote_name(f.column)
|
||||
for f in model._meta.local_concrete_fields
|
||||
}
|
||||
# This maps field names (not columns) for things like unique_together
|
||||
rename_mapping = {}
|
||||
# If any of the new or altered fields is introducing a new PK,
|
||||
# remove the old one
|
||||
restore_pk_field = None
|
||||
if getattr(create_field, "primary_key", False) or (
|
||||
alter_field and getattr(alter_field[1], "primary_key", False)
|
||||
):
|
||||
for name, field in list(body.items()):
|
||||
if field.primary_key:
|
||||
field.primary_key = False
|
||||
restore_pk_field = field
|
||||
if field.auto_created:
|
||||
del body[name]
|
||||
del mapping[field.column]
|
||||
# Add in any created fields
|
||||
if create_field:
|
||||
body[create_field.name] = create_field
|
||||
# Choose a default and insert it into the copy map
|
||||
if not create_field.many_to_many and create_field.concrete:
|
||||
mapping[create_field.column] = self.quote_value(
|
||||
self.effective_default(create_field)
|
||||
)
|
||||
# Add in any altered fields
|
||||
if alter_field:
|
||||
old_field, new_field = alter_field
|
||||
body.pop(old_field.name, None)
|
||||
mapping.pop(old_field.column, None)
|
||||
body[new_field.name] = new_field
|
||||
if old_field.null and not new_field.null:
|
||||
case_sql = "coalesce(%(col)s, %(default)s)" % {
|
||||
"col": self.quote_name(old_field.column),
|
||||
"default": self.quote_value(self.effective_default(new_field)),
|
||||
}
|
||||
mapping[new_field.column] = case_sql
|
||||
else:
|
||||
mapping[new_field.column] = self.quote_name(old_field.column)
|
||||
rename_mapping[old_field.name] = new_field.name
|
||||
# Remove any deleted fields
|
||||
if delete_field:
|
||||
del body[delete_field.name]
|
||||
del mapping[delete_field.column]
|
||||
# Remove any implicit M2M tables
|
||||
if (
|
||||
delete_field.many_to_many
|
||||
and delete_field.remote_field.through._meta.auto_created
|
||||
):
|
||||
return self.delete_model(delete_field.remote_field.through)
|
||||
# Work inside a new app registry
|
||||
apps = Apps()
|
||||
|
||||
# Work out the new value of unique_together, taking renames into
|
||||
# account
|
||||
unique_together = [
|
||||
[rename_mapping.get(n, n) for n in unique]
|
||||
for unique in model._meta.unique_together
|
||||
]
|
||||
|
||||
# Work out the new value for index_together, taking renames into
|
||||
# account
|
||||
index_together = [
|
||||
[rename_mapping.get(n, n) for n in index]
|
||||
for index in model._meta.index_together
|
||||
]
|
||||
|
||||
indexes = model._meta.indexes
|
||||
if delete_field:
|
||||
indexes = [
|
||||
index for index in indexes if delete_field.name not in index.fields
|
||||
]
|
||||
|
||||
constraints = list(model._meta.constraints)
|
||||
|
||||
# Provide isolated instances of the fields to the new model body so
|
||||
# that the existing model's internals aren't interfered with when
|
||||
# the dummy model is constructed.
|
||||
body_copy = copy.deepcopy(body)
|
||||
|
||||
# Construct a new model with the new fields to allow self referential
|
||||
# primary key to resolve to. This model won't ever be materialized as a
|
||||
# table and solely exists for foreign key reference resolution purposes.
|
||||
# This wouldn't be required if the schema editor was operating on model
|
||||
# states instead of rendered models.
|
||||
meta_contents = {
|
||||
"app_label": model._meta.app_label,
|
||||
"db_table": model._meta.db_table,
|
||||
"unique_together": unique_together,
|
||||
"index_together": index_together,
|
||||
"indexes": indexes,
|
||||
"constraints": constraints,
|
||||
"apps": apps,
|
||||
}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
body_copy["Meta"] = meta
|
||||
body_copy["__module__"] = model.__module__
|
||||
type(model._meta.object_name, model.__bases__, body_copy)
|
||||
|
||||
# Construct a model with a renamed table name.
|
||||
body_copy = copy.deepcopy(body)
|
||||
meta_contents = {
|
||||
"app_label": model._meta.app_label,
|
||||
"db_table": "new__%s" % strip_quotes(model._meta.db_table),
|
||||
"unique_together": unique_together,
|
||||
"index_together": index_together,
|
||||
"indexes": indexes,
|
||||
"constraints": constraints,
|
||||
"apps": apps,
|
||||
}
|
||||
meta = type("Meta", (), meta_contents)
|
||||
body_copy["Meta"] = meta
|
||||
body_copy["__module__"] = model.__module__
|
||||
new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)
|
||||
|
||||
# Create a new table with the updated schema.
|
||||
self.create_model(new_model)
|
||||
|
||||
# Copy data from the old table into the new table
|
||||
self.execute(
|
||||
"INSERT INTO %s (%s) SELECT %s FROM %s"
|
||||
% (
|
||||
self.quote_name(new_model._meta.db_table),
|
||||
", ".join(self.quote_name(x) for x in mapping),
|
||||
", ".join(mapping.values()),
|
||||
self.quote_name(model._meta.db_table),
|
||||
)
|
||||
)
|
||||
|
||||
# Delete the old table to make way for the new
|
||||
self.delete_model(model, handle_autom2m=False)
|
||||
|
||||
# Rename the new table to take way for the old
|
||||
self.alter_db_table(
|
||||
new_model,
|
||||
new_model._meta.db_table,
|
||||
model._meta.db_table,
|
||||
disable_constraints=False,
|
||||
)
|
||||
|
||||
# Run deferred SQL on correct table
|
||||
for sql in self.deferred_sql:
|
||||
self.execute(sql)
|
||||
self.deferred_sql = []
|
||||
# Fix any PK-removed field
|
||||
if restore_pk_field:
|
||||
restore_pk_field.primary_key = True
|
||||
|
||||
def delete_model(self, model, handle_autom2m=True):
|
||||
if handle_autom2m:
|
||||
super().delete_model(model)
|
||||
else:
|
||||
# Delete the table (and only that)
|
||||
self.execute(
|
||||
self.sql_delete_table
|
||||
% {
|
||||
"table": self.quote_name(model._meta.db_table),
|
||||
}
|
||||
)
|
||||
# Remove all deferred statements referencing the deleted table.
|
||||
for sql in list(self.deferred_sql):
|
||||
if isinstance(sql, Statement) and sql.references_table(
|
||||
model._meta.db_table
|
||||
):
|
||||
self.deferred_sql.remove(sql)
|
||||
|
||||
def add_field(self, model, field):
|
||||
"""
|
||||
Create a field on a model. Usually involves adding a column, but may
|
||||
involve adding a table instead (for M2M fields).
|
||||
"""
|
||||
# Special-case implicit M2M tables
|
||||
if field.many_to_many and field.remote_field.through._meta.auto_created:
|
||||
return self.create_model(field.remote_field.through)
|
||||
self._remake_table(model, create_field=field)
|
||||
|
||||
def remove_field(self, model, field):
|
||||
"""
|
||||
Remove a field from a model. Usually involves deleting a column,
|
||||
but for M2Ms may involve deleting a table.
|
||||
"""
|
||||
# M2M fields are a special case
|
||||
if field.many_to_many:
|
||||
# For implicit M2M tables, delete the auto-created table
|
||||
if field.remote_field.through._meta.auto_created:
|
||||
self.delete_model(field.remote_field.through)
|
||||
# For explicit "through" M2M fields, do nothing
|
||||
# For everything else, remake.
|
||||
else:
|
||||
# It might not actually have a column behind it
|
||||
if field.db_parameters(connection=self.connection)["type"] is None:
|
||||
return
|
||||
self._remake_table(model, delete_field=field)
|
||||
|
||||
def _alter_field(
|
||||
self,
|
||||
model,
|
||||
old_field,
|
||||
new_field,
|
||||
old_type,
|
||||
new_type,
|
||||
old_db_params,
|
||||
new_db_params,
|
||||
strict=False,
|
||||
):
|
||||
"""Perform a "physical" (non-ManyToMany) field update."""
|
||||
# Use "ALTER TABLE ... RENAME COLUMN" if only the column name
|
||||
# changed and there aren't any constraints.
|
||||
if (
|
||||
self.connection.features.can_alter_table_rename_column
|
||||
and old_field.column != new_field.column
|
||||
and self.column_sql(model, old_field) == self.column_sql(model, new_field)
|
||||
and not (
|
||||
old_field.remote_field
|
||||
and old_field.db_constraint
|
||||
or new_field.remote_field
|
||||
and new_field.db_constraint
|
||||
)
|
||||
):
|
||||
return self.execute(
|
||||
self._rename_field_sql(
|
||||
model._meta.db_table, old_field, new_field, new_type
|
||||
)
|
||||
)
|
||||
# Alter by remaking table
|
||||
self._remake_table(model, alter_field=(old_field, new_field))
|
||||
# Rebuild tables with FKs pointing to this field.
|
||||
if new_field.unique and old_type != new_type:
|
||||
related_models = set()
|
||||
opts = new_field.model._meta
|
||||
for remote_field in opts.related_objects:
|
||||
# Ignore self-relationship since the table was already rebuilt.
|
||||
if remote_field.related_model == model:
|
||||
continue
|
||||
if not remote_field.many_to_many:
|
||||
if remote_field.field_name == new_field.name:
|
||||
related_models.add(remote_field.related_model)
|
||||
elif new_field.primary_key and remote_field.through._meta.auto_created:
|
||||
related_models.add(remote_field.through)
|
||||
if new_field.primary_key:
|
||||
for many_to_many in opts.many_to_many:
|
||||
# Ignore self-relationship since the table was already rebuilt.
|
||||
if many_to_many.related_model == model:
|
||||
continue
|
||||
if many_to_many.remote_field.through._meta.auto_created:
|
||||
related_models.add(many_to_many.remote_field.through)
|
||||
for related_model in related_models:
|
||||
self._remake_table(related_model)
|
||||
|
||||
def _alter_many_to_many(self, model, old_field, new_field, strict):
|
||||
"""Alter M2Ms to repoint their to= endpoints."""
|
||||
if (
|
||||
old_field.remote_field.through._meta.db_table
|
||||
== new_field.remote_field.through._meta.db_table
|
||||
):
|
||||
# The field name didn't change, but some options did, so we have to
|
||||
# propagate this altering.
|
||||
self._remake_table(
|
||||
old_field.remote_field.through,
|
||||
alter_field=(
|
||||
# The field that points to the target model is needed, so
|
||||
# we can tell alter_field to change it - this is
|
||||
# m2m_reverse_field_name() (as opposed to m2m_field_name(),
|
||||
# which points to our model).
|
||||
old_field.remote_field.through._meta.get_field(
|
||||
old_field.m2m_reverse_field_name()
|
||||
),
|
||||
new_field.remote_field.through._meta.get_field(
|
||||
new_field.m2m_reverse_field_name()
|
||||
),
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
# Make a new through table
|
||||
self.create_model(new_field.remote_field.through)
|
||||
# Copy the data across
|
||||
self.execute(
|
||||
"INSERT INTO %s (%s) SELECT %s FROM %s"
|
||||
% (
|
||||
self.quote_name(new_field.remote_field.through._meta.db_table),
|
||||
", ".join(
|
||||
[
|
||||
"id",
|
||||
new_field.m2m_column_name(),
|
||||
new_field.m2m_reverse_name(),
|
||||
]
|
||||
),
|
||||
", ".join(
|
||||
[
|
||||
"id",
|
||||
old_field.m2m_column_name(),
|
||||
old_field.m2m_reverse_name(),
|
||||
]
|
||||
),
|
||||
self.quote_name(old_field.remote_field.through._meta.db_table),
|
||||
)
|
||||
)
|
||||
# Delete the old through table
|
||||
self.delete_model(old_field.remote_field.through)
|
||||
|
||||
def add_constraint(self, model, constraint):
|
||||
if isinstance(constraint, UniqueConstraint) and (
|
||||
constraint.condition
|
||||
or constraint.contains_expressions
|
||||
or constraint.include
|
||||
or constraint.deferrable
|
||||
):
|
||||
super().add_constraint(model, constraint)
|
||||
else:
|
||||
self._remake_table(model)
|
||||
|
||||
def remove_constraint(self, model, constraint):
|
||||
if isinstance(constraint, UniqueConstraint) and (
|
||||
constraint.condition
|
||||
or constraint.contains_expressions
|
||||
or constraint.include
|
||||
or constraint.deferrable
|
||||
):
|
||||
super().remove_constraint(model, constraint)
|
||||
else:
|
||||
self._remake_table(model)
|
||||
|
||||
def _collate_sql(self, collation):
|
||||
return "COLLATE " + collation
|
||||
Reference in New Issue
Block a user