测试gitnore

This commit is contained in:
ladeng07
2022-05-06 15:45:57 +08:00
parent 12f390949b
commit 51552904f9
2347 changed files with 120102 additions and 53549 deletions
File diff suppressed because it is too large Load Diff
@@ -3,37 +3,31 @@ from django.db import DatabaseError
class AmbiguityError(Exception):
"""More than one migration matches a name prefix."""
pass
class BadMigrationError(Exception):
"""There's a bad migration (unreadable/bad format/etc.)."""
pass
class CircularDependencyError(Exception):
"""There's an impossible-to-resolve circular dependency."""
pass
class InconsistentMigrationHistory(Exception):
"""An applied migration has some of its dependencies not applied."""
pass
class InvalidBasesError(ValueError):
"""A model's base classes can't be resolved."""
pass
class IrreversibleError(RuntimeError):
"""An irreversible migration is about to be reversed."""
pass
@@ -40,22 +40,13 @@ class MigrationExecutor:
# If the migration is already applied, do backwards mode,
# otherwise do forwards mode.
elif target in applied:
# If the target is missing, it's likely a replaced migration.
# Reload the graph without replacements.
if (
self.loader.replace_migrations
and target not in self.loader.graph.node_map
):
self.loader.replace_migrations = False
self.loader.build_graph()
return self.migration_plan(targets, clean_start=clean_start)
# Don't migrate backwards all the way to the target node (that
# may roll back dependencies in other apps that don't need to
# be rolled back); instead roll back through target's immediate
# child(ren) in the same app, and no further.
next_in_app = sorted(
n
for n in self.loader.graph.node_map[target].children
n for n in
self.loader.graph.node_map[target].children
if n[0] == target[0]
)
for node in next_in_app:
@@ -75,15 +66,12 @@ class MigrationExecutor:
Create a project state including all the applications without
migrations and applied migrations if with_applied_migrations=True.
"""
state = ProjectState(real_apps=self.loader.unmigrated_apps)
state = ProjectState(real_apps=list(self.loader.unmigrated_apps))
if with_applied_migrations:
# Create the forwards plan Django would follow on an empty database
full_plan = self.migration_plan(
self.loader.graph.leaf_nodes(), clean_start=True
)
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
applied_migrations = {
self.loader.graph.nodes[key]
for key in self.loader.applied_migrations
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
for migration, _ in full_plan:
@@ -105,9 +93,7 @@ class MigrationExecutor:
if plan is None:
plan = self.migration_plan(targets)
# Create the forwards plan Django would follow on an empty database
full_plan = self.migration_plan(
self.loader.graph.leaf_nodes(), clean_start=True
)
full_plan = self.migration_plan(self.loader.graph.leaf_nodes(), clean_start=True)
all_forwards = all(not backwards for mig, backwards in plan)
all_backwards = all(backwards for mig, backwards in plan)
@@ -122,15 +108,13 @@ class MigrationExecutor:
"Migration plans with both forwards and backwards migrations "
"are not supported. Please split your migration process into "
"separate plans of only forwards OR backwards migrations.",
plan,
plan
)
elif all_forwards:
if state is None:
# The resulting state should still include applied migrations.
state = self._create_project_state(with_applied_migrations=True)
state = self._migrate_all_forwards(
state, plan, full_plan, fake=fake, fake_initial=fake_initial
)
state = self._migrate_all_forwards(state, plan, full_plan, fake=fake, fake_initial=fake_initial)
else:
# No need to check for `elif all_backwards` here, as that condition
# would always evaluate to true.
@@ -154,15 +138,13 @@ class MigrationExecutor:
# process.
break
if migration in migrations_to_run:
if "apps" not in state.__dict__:
if 'apps' not in state.__dict__:
if self.progress_callback:
self.progress_callback("render_start")
state.apps # Render all -- performance critical
if self.progress_callback:
self.progress_callback("render_success")
state = self.apply_migration(
state, migration, fake=fake, fake_initial=fake_initial
)
state = self.apply_migration(state, migration, fake=fake, fake_initial=fake_initial)
migrations_to_run.remove(migration)
return state
@@ -182,8 +164,7 @@ class MigrationExecutor:
states = {}
state = self._create_project_state()
applied_migrations = {
self.loader.graph.nodes[key]
for key in self.loader.applied_migrations
self.loader.graph.nodes[key] for key in self.loader.applied_migrations
if key in self.loader.graph.nodes
}
if self.progress_callback:
@@ -196,7 +177,7 @@ class MigrationExecutor:
# process.
break
if migration in migrations_to_run:
if "apps" not in state.__dict__:
if 'apps' not in state.__dict__:
state.apps # Render all -- performance critical
# The state before this migration
states[migration] = state
@@ -242,9 +223,7 @@ class MigrationExecutor:
fake = True
if not fake:
# Alright, do it normally
with self.connection.schema_editor(
atomic=migration.atomic
) as schema_editor:
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
state = migration.apply(state, schema_editor)
if not schema_editor.deferred_sql:
self.record_migration(migration)
@@ -269,15 +248,14 @@ class MigrationExecutor:
if self.progress_callback:
self.progress_callback("unapply_start", migration, fake)
if not fake:
with self.connection.schema_editor(
atomic=migration.atomic
) as schema_editor:
with self.connection.schema_editor(atomic=migration.atomic) as schema_editor:
state = migration.unapply(state, schema_editor)
# For replacement migrations, also record individual statuses.
# For replacement migrations, record individual statuses
if migration.replaces:
for app_label, name in migration.replaces:
self.recorder.record_unapplied(app_label, name)
self.recorder.record_unapplied(migration.app_label, migration.name)
else:
self.recorder.record_unapplied(migration.app_label, migration.name)
# Report progress
if self.progress_callback:
self.progress_callback("unapply_success", migration, fake)
@@ -306,18 +284,15 @@ class MigrationExecutor:
tables or columns it would create exist. This is intended only for use
on initial migrations (as it only looks for CreateModel and AddField).
"""
def should_skip_detecting_model(migration, model):
"""
No need to detect tables for proxy models, unmanaged models, or
models that can't be migrated on the current database.
"""
return (
model._meta.proxy
or not model._meta.managed
or not router.allow_migrate(
self.connection.alias,
migration.app_label,
model._meta.proxy or not model._meta.managed or not
router.allow_migrate(
self.connection.alias, migration.app_label,
model_name=model._meta.model_name,
)
)
@@ -331,9 +306,7 @@ class MigrationExecutor:
return False, project_state
if project_state is None:
after_state = self.loader.project_state(
(migration.app_label, migration.name), at_end=True
)
after_state = self.loader.project_state((migration.app_label, migration.name), at_end=True)
else:
after_state = migration.mutate_state(project_state)
apps = after_state.apps
@@ -341,13 +314,9 @@ class MigrationExecutor:
found_add_field_migration = False
fold_identifier_case = self.connection.features.ignores_table_name_case
with self.connection.cursor() as cursor:
existing_table_names = set(
self.connection.introspection.table_names(cursor)
)
existing_table_names = set(self.connection.introspection.table_names(cursor))
if fold_identifier_case:
existing_table_names = {
name.casefold() for name in existing_table_names
}
existing_table_names = {name.casefold() for name in existing_table_names}
# Make sure all create model and add field operations are done
for operation in migration.operations:
if isinstance(operation, migrations.CreateModel):
@@ -387,9 +356,7 @@ class MigrationExecutor:
found_add_field_migration = True
continue
with self.connection.cursor() as cursor:
columns = self.connection.introspection.get_table_description(
cursor, table
)
columns = self.connection.introspection.get_table_description(cursor, table)
for column in columns:
field_column = field.column
column_name = column.name
@@ -401,6 +368,6 @@ class MigrationExecutor:
break
else:
return False, project_state
# If we get this far and we found at least one CreateModel or AddField
# migration, the migration is considered implicitly applied.
# If we get this far and we found at least one CreateModel or AddField migration,
# the migration is considered implicitly applied.
return (found_create_model_migration or found_add_field_migration), after_state
@@ -11,7 +11,6 @@ class Node:
A single node in the migration graph. Contains direct links to adjacent
nodes in either direction.
"""
def __init__(self, key):
self.key = key
self.children = set()
@@ -33,7 +32,7 @@ class Node:
return str(self.key)
def __repr__(self):
return "<%s: (%r, %r)>" % (self.__class__.__name__, self.key[0], self.key[1])
return '<%s: (%r, %r)>' % (self.__class__.__name__, self.key[0], self.key[1])
def add_child(self, child):
self.children.add(child)
@@ -50,7 +49,6 @@ class DummyNode(Node):
After the migration graph is processed, all dummy nodes should be removed.
If there are any left, a nonexistent dependency error is raised.
"""
def __init__(self, key, origin, error_message):
super().__init__(key)
self.origin = origin
@@ -102,7 +100,7 @@ class MigrationGraph:
"""
This may create dummy nodes if they don't yet exist. If
`skip_validation=True`, validate_consistency() should be called
afterward.
afterwards.
"""
if child not in self.nodes:
error_message = (
@@ -135,7 +133,7 @@ class MigrationGraph:
raise NodeNotFoundError(
"Unable to find replacement node %r. It was either never added"
" to the migration graph, or has been removed." % (replacement,),
replacement,
replacement
) from err
for replaced_key in replaced:
self.nodes.pop(replaced_key, None)
@@ -169,9 +167,8 @@ class MigrationGraph:
except KeyError as err:
raise NodeNotFoundError(
"Unable to remove replacement node %r. It was either never added"
" to the migration graph, or has been removed already."
% (replacement,),
replacement,
" to the migration graph, or has been removed already." % (replacement,),
replacement
) from err
replaced_nodes = set()
replaced_nodes_parents = set()
@@ -231,10 +228,7 @@ class MigrationGraph:
visited.append(node.key)
else:
stack.append((node, True))
stack += [
(n, False)
for n in sorted(node.parents if forwards else node.children)
]
stack += [(n, False) for n in sorted(node.parents if forwards else node.children)]
return visited
def root_nodes(self, app=None):
@@ -244,9 +238,7 @@ class MigrationGraph:
"""
roots = set()
for node in self.nodes:
if all(key[0] != node[0] for key in self.node_map[node].parents) and (
not app or app == node[0]
):
if all(key[0] != node[0] for key in self.node_map[node].parents) and (not app or app == node[0]):
roots.add(node)
return sorted(roots)
@@ -260,9 +252,7 @@ class MigrationGraph:
"""
leaves = set()
for node in self.nodes:
if all(key[0] != node[0] for key in self.node_map[node].children) and (
not app or app == node[0]
):
if all(key[0] != node[0] for key in self.node_map[node].children) and (not app or app == node[0]):
leaves.add(node)
return sorted(leaves)
@@ -280,10 +270,8 @@ class MigrationGraph:
# hashing.
node = child.key
if node in stack:
cycle = stack[stack.index(node) :]
raise CircularDependencyError(
", ".join("%s.%s" % n for n in cycle)
)
cycle = stack[stack.index(node):]
raise CircularDependencyError(", ".join("%s.%s" % n for n in cycle))
if node in todo:
stack.append(node)
todo.remove(node)
@@ -292,16 +280,14 @@ class MigrationGraph:
node = stack.pop()
def __str__(self):
return "Graph: %s nodes, %s edges" % self._nodes_and_edges()
return 'Graph: %s nodes, %s edges' % self._nodes_and_edges()
def __repr__(self):
nodes, edges = self._nodes_and_edges()
return "<%s: nodes=%s, edges=%s>" % (self.__class__.__name__, nodes, edges)
return '<%s: nodes=%s, edges=%s>' % (self.__class__.__name__, nodes, edges)
def _nodes_and_edges(self):
return len(self.nodes), sum(
len(node.parents) for node in self.node_map.values()
)
return len(self.nodes), sum(len(node.parents) for node in self.node_map.values())
def _generate_plan(self, nodes, at_end):
plan = []
@@ -8,13 +8,11 @@ from django.db.migrations.graph import MigrationGraph
from django.db.migrations.recorder import MigrationRecorder
from .exceptions import (
AmbiguityError,
BadMigrationError,
InconsistentMigrationHistory,
AmbiguityError, BadMigrationError, InconsistentMigrationHistory,
NodeNotFoundError,
)
MIGRATIONS_MODULE_NAME = "migrations"
MIGRATIONS_MODULE_NAME = 'migrations'
class MigrationLoader:
@@ -43,10 +41,7 @@ class MigrationLoader:
"""
def __init__(
self,
connection,
load=True,
ignore_no_migrations=False,
self, connection, load=True, ignore_no_migrations=False,
replace_migrations=True,
):
self.connection = connection
@@ -68,7 +63,7 @@ class MigrationLoader:
return settings.MIGRATION_MODULES[app_label], True
else:
app_package_name = apps.get_app_config(app_label).name
return "%s.%s" % (app_package_name, MIGRATIONS_MODULE_NAME), False
return '%s.%s' % (app_package_name, MIGRATIONS_MODULE_NAME), False
def load_disk(self):
"""Load the migrations from all INSTALLED_APPS from disk."""
@@ -85,22 +80,24 @@ class MigrationLoader:
try:
module = import_module(module_name)
except ModuleNotFoundError as e:
if (explicit and self.ignore_no_migrations) or (
not explicit and MIGRATIONS_MODULE_NAME in e.name.split(".")
if (
(explicit and self.ignore_no_migrations) or
(not explicit and MIGRATIONS_MODULE_NAME in e.name.split('.'))
):
self.unmigrated_apps.add(app_config.label)
continue
raise
else:
# Module is not a package (e.g. migrations.py).
if not hasattr(module, "__path__"):
if not hasattr(module, '__path__'):
self.unmigrated_apps.add(app_config.label)
continue
# Empty directories are namespaces. Namespace packages have no
# __file__ and don't use a list for __path__. See
# https://docs.python.org/3/reference/import.html#namespace-packages
if getattr(module, "__file__", None) is None and not isinstance(
module.__path__, list
if (
getattr(module, '__file__', None) is None and
not isinstance(module.__path__, list)
):
self.unmigrated_apps.add(app_config.label)
continue
@@ -109,17 +106,16 @@ class MigrationLoader:
reload(module)
self.migrated_apps.add(app_config.label)
migration_names = {
name
for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
if not is_pkg and name[0] not in "_~"
name for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
if not is_pkg and name[0] not in '_~'
}
# Load migrations
for migration_name in migration_names:
migration_path = "%s.%s" % (module_name, migration_name)
migration_path = '%s.%s' % (module_name, migration_name)
try:
migration_module = import_module(migration_path)
except ImportError as e:
if "bad magic number" in str(e):
if 'bad magic number' in str(e):
raise ImportError(
"Couldn't import %r as it appears to be a stale "
".pyc file." % migration_path
@@ -128,12 +124,9 @@ class MigrationLoader:
raise
if not hasattr(migration_module, "Migration"):
raise BadMigrationError(
"Migration %s in app %s has no Migration class"
% (migration_name, app_config.label)
"Migration %s in app %s has no Migration class" % (migration_name, app_config.label)
)
self.disk_migrations[
app_config.label, migration_name
] = migration_module.Migration(
self.disk_migrations[app_config.label, migration_name] = migration_module.Migration(
migration_name,
app_config.label,
)
@@ -149,20 +142,14 @@ class MigrationLoader:
# Do the search
results = []
for migration_app_label, migration_name in self.disk_migrations:
if migration_app_label == app_label and migration_name.startswith(
name_prefix
):
if migration_app_label == app_label and migration_name.startswith(name_prefix):
results.append((migration_app_label, migration_name))
if len(results) > 1:
raise AmbiguityError(
"There is more than one migration for '%s' with the prefix '%s'"
% (app_label, name_prefix)
"There is more than one migration for '%s' with the prefix '%s'" % (app_label, name_prefix)
)
elif not results:
raise KeyError(
f"There is no migration for '{app_label}' with the prefix "
f"'{name_prefix}'"
)
raise KeyError("There no migrations for '%s' with the prefix '%s'" % (app_label, name_prefix))
else:
return self.disk_migrations[results[0]]
@@ -191,9 +178,7 @@ class MigrationLoader:
if self.ignore_no_migrations:
return None
else:
raise ValueError(
"Dependency on app with no migrations: %s" % key[0]
)
raise ValueError("Dependency on app with no migrations: %s" % key[0])
raise ValueError("Dependency on unknown app: %s" % key[0])
def add_internal_dependencies(self, key, migration):
@@ -203,7 +188,7 @@ class MigrationLoader:
"""
for parent in migration.dependencies:
# Ignore __first__ references to the same app.
if parent[0] == key[0] and parent[1] != "__first__":
if parent[0] == key[0] and parent[1] != '__first__':
self.graph.add_dependency(migration, key, parent, skip_validation=True)
def add_external_dependencies(self, key, migration):
@@ -253,9 +238,7 @@ class MigrationLoader:
for key, migration in self.replacements.items():
# Get applied status of each of this migration's replacement
# targets.
applied_statuses = [
(target in self.applied_migrations) for target in migration.replaces
]
applied_statuses = [(target in self.applied_migrations) for target in migration.replaces]
# The replacing migration is only marked as applied if all of
# its replacement targets are.
if all(applied_statuses):
@@ -287,11 +270,9 @@ class MigrationLoader:
# Try to reraise exception with more detail.
if exc.node in reverse_replacements:
candidates = reverse_replacements.get(exc.node, set())
is_replaced = any(
candidate in self.graph.nodes for candidate in candidates
)
is_replaced = any(candidate in self.graph.nodes for candidate in candidates)
if not is_replaced:
tries = ", ".join("%s.%s" % c for c in candidates)
tries = ', '.join('%s.%s' % c for c in candidates)
raise NodeNotFoundError(
"Migration {0} depends on nonexistent node ('{1}', '{2}'). "
"Django tried to replace migration {1}.{2} with any of [{3}] "
@@ -299,7 +280,7 @@ class MigrationLoader:
"are already applied.".format(
exc.origin, exc.node[0], exc.node[1], tries
),
exc.node,
exc.node
) from exc
raise
self.graph.ensure_not_cyclic()
@@ -320,17 +301,12 @@ class MigrationLoader:
# Skip unapplied squashed migrations that have all of their
# `replaces` applied.
if parent in self.replacements:
if all(
m in applied for m in self.replacements[parent].replaces
):
if all(m in applied for m in self.replacements[parent].replaces):
continue
raise InconsistentMigrationHistory(
"Migration {}.{} is applied before its dependency "
"{}.{} on database '{}'.".format(
migration[0],
migration[1],
parent[0],
parent[1],
migration[0], migration[1], parent[0], parent[1],
connection.alias,
)
)
@@ -347,9 +323,7 @@ class MigrationLoader:
if app_label in seen_apps:
conflicting_apps.add(app_label)
seen_apps.setdefault(app_label, set()).add(migration_name)
return {
app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps
}
return {app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps}
def project_state(self, nodes=None, at_end=True):
"""
@@ -358,9 +332,7 @@ class MigrationLoader:
See graph.make_state() for the meaning of "nodes" and "at_end".
"""
return self.graph.make_state(
nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps
)
return self.graph.make_state(nodes=nodes, at_end=at_end, real_apps=list(self.unmigrated_apps))
def collect_sql(self, plan):
"""
@@ -370,13 +342,9 @@ class MigrationLoader:
statements = []
state = None
for migration, backwards in plan:
with self.connection.schema_editor(
collect_sql=True, atomic=migration.atomic
) as schema_editor:
with self.connection.schema_editor(collect_sql=True, atomic=migration.atomic) as schema_editor:
if state is None:
state = self.project_state(
(migration.app_label, migration.name), at_end=False
)
state = self.project_state((migration.app_label, migration.name), at_end=False)
if not backwards:
state = migration.apply(state, schema_editor, collect_sql=True)
else:
@@ -1,3 +1,4 @@
from django.db.migrations import operations
from django.db.migrations.utils import get_migration_name_timestamp
from django.db.transaction import atomic
@@ -12,8 +13,7 @@ class Migration:
and subclass it as a class called Migration. It will have one or more
of the following attributes:
- operations: A list of Operation instances, probably from
django.db.migrations.operations
- operations: A list of Operation instances, probably from django.db.migrations.operations
- dependencies: A list of tuples of (app_path, migration_name)
- run_before: A list of tuples of (app_path, migration_name)
- replaces: A list of migration_names
@@ -61,9 +61,9 @@ class Migration:
def __eq__(self, other):
return (
isinstance(other, Migration)
and self.name == other.name
and self.app_label == other.app_label
isinstance(other, Migration) and
self.name == other.name and
self.app_label == other.app_label
)
def __repr__(self):
@@ -105,8 +105,7 @@ class Migration:
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
schema_editor.collected_sql.append(
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS "
"SQL:"
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
)
schema_editor.collected_sql.append("-- %s" % operation.describe())
schema_editor.collected_sql.append("--")
@@ -116,21 +115,15 @@ class Migration:
old_state = project_state.clone()
operation.state_forwards(self.app_label, project_state)
# Run the operation
atomic_operation = operation.atomic or (
self.atomic and operation.atomic is not False
)
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
operation.database_forwards(
self.app_label, schema_editor, old_state, project_state
)
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
else:
# Normal behaviour
operation.database_forwards(
self.app_label, schema_editor, old_state, project_state
)
operation.database_forwards(self.app_label, schema_editor, old_state, project_state)
return project_state
def unapply(self, project_state, schema_editor, collect_sql=False):
@@ -153,9 +146,7 @@ class Migration:
for operation in self.operations:
# If it's irreversible, error out
if not operation.reversible:
raise IrreversibleError(
"Operation %s in %s is not reversible" % (operation, self)
)
raise IrreversibleError("Operation %s in %s is not reversible" % (operation, self))
# Preserve new state from previous run to not tamper the same state
# over all operations
new_state = new_state.clone()
@@ -169,28 +160,21 @@ class Migration:
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
schema_editor.collected_sql.append(
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS "
"SQL:"
"-- MIGRATION NOW PERFORMS OPERATION THAT CANNOT BE WRITTEN AS SQL:"
)
schema_editor.collected_sql.append("-- %s" % operation.describe())
schema_editor.collected_sql.append("--")
if not operation.reduces_to_sql:
continue
atomic_operation = operation.atomic or (
self.atomic and operation.atomic is not False
)
atomic_operation = operation.atomic or (self.atomic and operation.atomic is not False)
if not schema_editor.atomic_migration and atomic_operation:
# Force a transaction on a non-transactional-DDL backend or an
# atomic operation inside a non-atomic migration.
with atomic(schema_editor.connection.alias):
operation.database_backwards(
self.app_label, schema_editor, from_state, to_state
)
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
else:
# Normal behaviour
operation.database_backwards(
self.app_label, schema_editor, from_state, to_state
)
operation.database_backwards(self.app_label, schema_editor, from_state, to_state)
return project_state
def suggest_name(self):
@@ -199,22 +183,16 @@ class Migration:
are not guaranteed to be unique, but put some effort into the fallback
name to avoid VCS conflicts if possible.
"""
if self.initial:
return "initial"
raw_fragments = [op.migration_name_fragment for op in self.operations]
fragments = [name for name in raw_fragments if name]
if not fragments or len(fragments) != len(self.operations):
return "auto_%s" % get_migration_name_timestamp()
name = fragments[0]
for fragment in fragments[1:]:
new_name = f"{name}_{fragment}"
if len(new_name) > 52:
name = f"{name}_and_more"
break
name = new_name
name = None
if len(self.operations) == 1:
name = self.operations[0].migration_name_fragment
elif (
len(self.operations) > 1 and
all(isinstance(o, operations.CreateModel) for o in self.operations)
):
name = '_'.join(sorted(o.migration_name_fragment for o in self.operations))
if name is None:
name = 'initial' if self.initial else 'auto_%s' % get_migration_name_timestamp()
return name
@@ -1,40 +1,17 @@
from .fields import AddField, AlterField, RemoveField, RenameField
from .models import (
AddConstraint,
AddIndex,
AlterIndexTogether,
AlterModelManagers,
AlterModelOptions,
AlterModelTable,
AlterOrderWithRespectTo,
AlterUniqueTogether,
CreateModel,
DeleteModel,
RemoveConstraint,
RemoveIndex,
RenameModel,
AddConstraint, AddIndex, AlterIndexTogether, AlterModelManagers,
AlterModelOptions, AlterModelTable, AlterOrderWithRespectTo,
AlterUniqueTogether, CreateModel, DeleteModel, RemoveConstraint,
RemoveIndex, RenameModel,
)
from .special import RunPython, RunSQL, SeparateDatabaseAndState
__all__ = [
"CreateModel",
"DeleteModel",
"AlterModelTable",
"AlterUniqueTogether",
"RenameModel",
"AlterIndexTogether",
"AlterModelOptions",
"AddIndex",
"RemoveIndex",
"AddField",
"RemoveField",
"AlterField",
"RenameField",
"AddConstraint",
"RemoveConstraint",
"SeparateDatabaseAndState",
"RunSQL",
"RunPython",
"AlterOrderWithRespectTo",
"AlterModelManagers",
'CreateModel', 'DeleteModel', 'AlterModelTable', 'AlterUniqueTogether',
'RenameModel', 'AlterIndexTogether', 'AlterModelOptions', 'AddIndex',
'RemoveIndex', 'AddField', 'RemoveField', 'AlterField', 'RenameField',
'AddConstraint', 'RemoveConstraint',
'SeparateDatabaseAndState', 'RunSQL', 'RunPython',
'AlterOrderWithRespectTo', 'AlterModelManagers',
]
@@ -56,18 +56,14 @@ class Operation:
Take the state from the previous migration, and mutate it
so that it matches what this migration would perform.
"""
raise NotImplementedError(
"subclasses of Operation must provide a state_forwards() method"
)
raise NotImplementedError('subclasses of Operation must provide a state_forwards() method')
def database_forwards(self, app_label, schema_editor, from_state, to_state):
"""
Perform the mutation on the database schema in the normal
(forwards) direction.
"""
raise NotImplementedError(
"subclasses of Operation must provide a database_forwards() method"
)
raise NotImplementedError('subclasses of Operation must provide a database_forwards() method')
def database_backwards(self, app_label, schema_editor, from_state, to_state):
"""
@@ -75,9 +71,7 @@ class Operation:
direction - e.g. if this were CreateModel, it would in fact
drop the model's table.
"""
raise NotImplementedError(
"subclasses of Operation must provide a database_backwards() method"
)
raise NotImplementedError('subclasses of Operation must provide a database_backwards() method')
def describe(self):
"""
@@ -1,8 +1,9 @@
from django.db.migrations.utils import field_references
from django.core.exceptions import FieldDoesNotExist
from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property
from .base import Operation
from .utils import field_is_referenced, field_references, get_references
class FieldOperation(Operation):
@@ -23,23 +24,16 @@ class FieldOperation(Operation):
return self.model_name_lower == operation.model_name_lower
def is_same_field_operation(self, operation):
return (
self.is_same_model_operation(operation)
and self.name_lower == operation.name_lower
)
return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower
def references_model(self, name, app_label):
name_lower = name.lower()
if name_lower == self.model_name_lower:
return True
if self.field:
return bool(
field_references(
(app_label, self.model_name_lower),
self.field,
(app_label, name_lower),
)
)
return bool(field_references(
(app_label, self.model_name_lower), self.field, (app_label, name_lower)
))
return False
def references_field(self, model_name, name, app_label):
@@ -48,27 +42,22 @@ class FieldOperation(Operation):
if model_name_lower == self.model_name_lower:
if name == self.name:
return True
elif (
self.field
and hasattr(self.field, "from_fields")
and name in self.field.from_fields
):
elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
return True
# Check if this operation remotely references the field.
if self.field is None:
return False
return bool(
field_references(
(app_label, self.model_name_lower),
self.field,
(app_label, model_name_lower),
name,
)
)
return bool(field_references(
(app_label, self.model_name_lower),
self.field,
(app_label, model_name_lower),
name,
))
def reduce(self, operation, app_label):
return super().reduce(operation, app_label) or not operation.references_field(
self.model_name, self.name, app_label
return (
super().reduce(operation, app_label) or
not operation.references_field(self.model_name, self.name, app_label)
)
@@ -81,22 +70,29 @@ class AddField(FieldOperation):
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"name": self.name,
"field": self.field,
'model_name': self.model_name,
'name': self.name,
'field': self.field,
}
if self.preserve_default is not True:
kwargs["preserve_default"] = self.preserve_default
return (self.__class__.__name__, [], kwargs)
kwargs['preserve_default'] = self.preserve_default
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.add_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
)
# If preserve default is off, don't use the default for future state
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
state.models[app_label, self.model_name_lower].fields[self.name] = field
# Delay rendering of relationships if it's not a relational field
delay = not field.is_relation
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -115,21 +111,17 @@ class AddField(FieldOperation):
def database_backwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
schema_editor.remove_field(
from_model, from_model._meta.get_field(self.name)
)
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
def describe(self):
return "Add field %s to %s" % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return "%s_%s" % (self.model_name_lower, self.name_lower)
return '%s_%s' % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
if isinstance(operation, FieldOperation) and self.is_same_field_operation(
operation
):
if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
if isinstance(operation, AlterField):
return [
AddField(
@@ -156,20 +148,26 @@ class RemoveField(FieldOperation):
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"name": self.name,
'model_name': self.model_name,
'name': self.name,
}
return (self.__class__.__name__, [], kwargs)
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.remove_field(app_label, self.model_name_lower, self.name)
model_state = state.models[app_label, self.model_name_lower]
old_field = model_state.fields.pop(self.name)
# Delay rendering of relationships if it's not a relational field
delay = not old_field.is_relation
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
from_model = from_state.apps.get_model(app_label, self.model_name)
if self.allow_migrate_model(schema_editor.connection.alias, from_model):
schema_editor.remove_field(
from_model, from_model._meta.get_field(self.name)
)
schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
def database_backwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -182,15 +180,11 @@ class RemoveField(FieldOperation):
@property
def migration_name_fragment(self):
return "remove_%s_%s" % (self.model_name_lower, self.name_lower)
return 'remove_%s_%s' % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
from .models import DeleteModel
if (
isinstance(operation, DeleteModel)
and operation.name_lower == self.model_name_lower
):
if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
return [operation]
return super().reduce(operation, app_label)
@@ -207,22 +201,37 @@ class AlterField(FieldOperation):
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"name": self.name,
"field": self.field,
'model_name': self.model_name,
'name': self.name,
'field': self.field,
}
if self.preserve_default is not True:
kwargs["preserve_default"] = self.preserve_default
return (self.__class__.__name__, [], kwargs)
kwargs['preserve_default'] = self.preserve_default
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_field(
app_label,
self.model_name_lower,
self.name,
self.field,
self.preserve_default,
if not self.preserve_default:
field = self.field.clone()
field.default = NOT_PROVIDED
else:
field = self.field
model_state = state.models[app_label, self.model_name_lower]
model_state.fields[self.name] = field
# TODO: investigate if old relational fields must be reloaded or if it's
# sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = (
not field.is_relation and
not field_is_referenced(
state, (app_label, self.model_name_lower), (self.name, field),
)
)
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -244,16 +253,12 @@ class AlterField(FieldOperation):
@property
def migration_name_fragment(self):
return "alter_%s_%s" % (self.model_name_lower, self.name_lower)
return 'alter_%s_%s' % (self.model_name_lower, self.name_lower)
def reduce(self, operation, app_label):
if isinstance(operation, RemoveField) and self.is_same_field_operation(
operation
):
if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
return [operation]
elif isinstance(operation, RenameField) and self.is_same_field_operation(
operation
):
elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
return [
operation,
AlterField(
@@ -283,16 +288,60 @@ class RenameField(FieldOperation):
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"old_name": self.old_name,
"new_name": self.new_name,
'model_name': self.model_name,
'old_name': self.old_name,
'new_name': self.new_name,
}
return (self.__class__.__name__, [], kwargs)
return (
self.__class__.__name__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.rename_field(
app_label, self.model_name_lower, self.old_name, self.new_name
model_state = state.models[app_label, self.model_name_lower]
# Rename the field
fields = model_state.fields
try:
found = fields.pop(self.old_name)
except KeyError:
raise FieldDoesNotExist(
"%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
)
fields[self.new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
from_fields = getattr(field, 'from_fields', None)
if from_fields:
field.from_fields = tuple([
self.new_name if from_field_name == self.old_name else from_field_name
for from_field_name in from_fields
])
# Fix index/unique_together to refer to the new field
options = model_state.options
for option in ('index_together', 'unique_together'):
if option in options:
options[option] = [
[self.new_name if n == self.old_name else n for n in together]
for together in options[option]
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(
state, (app_label, self.model_name_lower), (self.old_name, found),
)
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, 'field_name', None) == self.old_name:
remote_field.field_name = self.new_name
if to_fields:
field.to_fields = tuple([
self.new_name if to_field_name == self.old_name else to_field_name
for to_field_name in to_fields
])
state.reload_model(app_label, self.model_name_lower, delay=delay)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.model_name)
@@ -315,15 +364,11 @@ class RenameField(FieldOperation):
)
def describe(self):
return "Rename field %s on %s to %s" % (
self.old_name,
self.model_name,
self.new_name,
)
return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
@property
def migration_name_fragment(self):
return "rename_%s_%s_%s" % (
return 'rename_%s_%s_%s' % (
self.old_name_lower,
self.model_name_lower,
self.new_name_lower,
@@ -331,15 +376,14 @@ class RenameField(FieldOperation):
def references_field(self, model_name, name, app_label):
return self.references_model(model_name, app_label) and (
name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
name.lower() == self.old_name_lower or
name.lower() == self.new_name_lower
)
def reduce(self, operation, app_label):
if (
isinstance(operation, RenameField)
and self.is_same_model_operation(operation)
and self.new_name_lower == operation.old_name_lower
):
if (isinstance(operation, RenameField) and
self.is_same_model_operation(operation) and
self.new_name_lower == operation.old_name_lower):
return [
RenameField(
self.model_name,
@@ -348,8 +392,8 @@ class RenameField(FieldOperation):
),
]
# Skip `FieldOperation.reduce` as we want to run `references_field`
# against self.old_name and self.new_name.
return super(FieldOperation, self).reduce(operation, app_label) or not (
operation.references_field(self.model_name, self.old_name, app_label)
or operation.references_field(self.model_name, self.new_name, app_label)
# against self.new_name.
return (
super(FieldOperation, self).reduce(operation, app_label) or
not operation.references_field(self.model_name, self.new_name, app_label)
)
@@ -1,11 +1,13 @@
from django.db import models
from django.db.migrations.operations.base import Operation
from django.db.migrations.state import ModelState
from django.db.migrations.utils import field_references, resolve_relation
from django.db.models.options import normalize_together
from django.utils.functional import cached_property
from .fields import AddField, AlterField, FieldOperation, RemoveField, RenameField
from .fields import (
AddField, AlterField, FieldOperation, RemoveField, RenameField,
)
from .utils import field_references, get_references, resolve_relation
def _check_for_duplicates(arg_name, objs):
@@ -30,15 +32,16 @@ class ModelOperation(Operation):
return name.lower() == self.name_lower
def reduce(self, operation, app_label):
return super().reduce(operation, app_label) or not operation.references_model(
self.name, app_label
return (
super().reduce(operation, app_label) or
not operation.references_model(self.name, app_label)
)
class CreateModel(ModelOperation):
"""Create a model's table."""
serialization_expand_args = ["fields", "options", "managers"]
serialization_expand_args = ['fields', 'options', 'managers']
def __init__(self, name, fields, options=None, bases=None, managers=None):
self.fields = fields
@@ -48,44 +51,40 @@ class CreateModel(ModelOperation):
super().__init__(name)
# Sanity-check that there are no duplicated field names, bases, or
# manager names
_check_for_duplicates("fields", (name for name, _ in self.fields))
_check_for_duplicates(
"bases",
(
base._meta.label_lower
if hasattr(base, "_meta")
else base.lower()
if isinstance(base, str)
else base
for base in self.bases
),
)
_check_for_duplicates("managers", (name for name, _ in self.managers))
_check_for_duplicates('fields', (name for name, _ in self.fields))
_check_for_duplicates('bases', (
base._meta.label_lower if hasattr(base, '_meta') else
base.lower() if isinstance(base, str) else base
for base in self.bases
))
_check_for_duplicates('managers', (name for name, _ in self.managers))
def deconstruct(self):
kwargs = {
"name": self.name,
"fields": self.fields,
'name': self.name,
'fields': self.fields,
}
if self.options:
kwargs["options"] = self.options
kwargs['options'] = self.options
if self.bases and self.bases != (models.Model,):
kwargs["bases"] = self.bases
if self.managers and self.managers != [("objects", models.Manager())]:
kwargs["managers"] = self.managers
return (self.__class__.__qualname__, [], kwargs)
kwargs['bases'] = self.bases
if self.managers and self.managers != [('objects', models.Manager())]:
kwargs['managers'] = self.managers
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.add_model(
ModelState(
app_label,
self.name,
list(self.fields),
dict(self.options),
tuple(self.bases),
list(self.managers),
)
)
state.add_model(ModelState(
app_label,
self.name,
list(self.fields),
dict(self.options),
tuple(self.bases),
list(self.managers),
))
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.name)
@@ -98,10 +97,7 @@ class CreateModel(ModelOperation):
schema_editor.delete_model(model)
def describe(self):
return "Create %smodel %s" % (
"proxy " if self.options.get("proxy", False) else "",
self.name,
)
return "Create %smodel %s" % ("proxy " if self.options.get("proxy", False) else "", self.name)
@property
def migration_name_fragment(self):
@@ -115,32 +111,22 @@ class CreateModel(ModelOperation):
# Check we didn't inherit from the model
reference_model_tuple = (app_label, name_lower)
for base in self.bases:
if (
base is not models.Model
and isinstance(base, (models.base.ModelBase, str))
and resolve_relation(base, app_label) == reference_model_tuple
):
if (base is not models.Model and isinstance(base, (models.base.ModelBase, str)) and
resolve_relation(base, app_label) == reference_model_tuple):
return True
# Check we have no FKs/M2Ms with it
for _name, field in self.fields:
if field_references(
(app_label, self.name_lower), field, reference_model_tuple
):
if field_references((app_label, self.name_lower), field, reference_model_tuple):
return True
return False
def reduce(self, operation, app_label):
if (
isinstance(operation, DeleteModel)
and self.name_lower == operation.name_lower
and not self.options.get("proxy", False)
):
if (isinstance(operation, DeleteModel) and
self.name_lower == operation.name_lower and
not self.options.get("proxy", False)):
return []
elif (
isinstance(operation, RenameModel)
and self.name_lower == operation.old_name_lower
):
elif isinstance(operation, RenameModel) and self.name_lower == operation.old_name_lower:
return [
CreateModel(
operation.new_name,
@@ -150,10 +136,7 @@ class CreateModel(ModelOperation):
managers=self.managers,
),
]
elif (
isinstance(operation, AlterModelOptions)
and self.name_lower == operation.name_lower
):
elif isinstance(operation, AlterModelOptions) and self.name_lower == operation.name_lower:
options = {**self.options, **operation.options}
for key in operation.ALTER_OPTION_KEYS:
if key not in operation.options:
@@ -167,42 +150,27 @@ class CreateModel(ModelOperation):
managers=self.managers,
),
]
elif (
isinstance(operation, AlterTogetherOptionOperation)
and self.name_lower == operation.name_lower
):
elif isinstance(operation, AlterTogetherOptionOperation) and self.name_lower == operation.name_lower:
return [
CreateModel(
self.name,
fields=self.fields,
options={
**self.options,
**{operation.option_name: operation.option_value},
},
options={**self.options, **{operation.option_name: operation.option_value}},
bases=self.bases,
managers=self.managers,
),
]
elif (
isinstance(operation, AlterOrderWithRespectTo)
and self.name_lower == operation.name_lower
):
elif isinstance(operation, AlterOrderWithRespectTo) and self.name_lower == operation.name_lower:
return [
CreateModel(
self.name,
fields=self.fields,
options={
**self.options,
"order_with_respect_to": operation.order_with_respect_to,
},
options={**self.options, 'order_with_respect_to': operation.order_with_respect_to},
bases=self.bases,
managers=self.managers,
),
]
elif (
isinstance(operation, FieldOperation)
and self.name_lower == operation.model_name_lower
):
elif isinstance(operation, FieldOperation) and self.name_lower == operation.model_name_lower:
if isinstance(operation, AddField):
return [
CreateModel(
@@ -228,25 +196,17 @@ class CreateModel(ModelOperation):
]
elif isinstance(operation, RemoveField):
options = self.options.copy()
for option_name in ("unique_together", "index_together"):
for option_name in ('unique_together', 'index_together'):
option = options.pop(option_name, None)
if option:
option = set(
filter(
bool,
(
tuple(
f for f in fields if f != operation.name_lower
)
for fields in option
),
)
)
option = set(filter(bool, (
tuple(f for f in fields if f != operation.name_lower) for fields in option
)))
if option:
options[option_name] = option
order_with_respect_to = options.get("order_with_respect_to")
order_with_respect_to = options.get('order_with_respect_to')
if order_with_respect_to == operation.name_lower:
del options["order_with_respect_to"]
del options['order_with_respect_to']
return [
CreateModel(
self.name,
@@ -262,19 +222,16 @@ class CreateModel(ModelOperation):
]
elif isinstance(operation, RenameField):
options = self.options.copy()
for option_name in ("unique_together", "index_together"):
for option_name in ('unique_together', 'index_together'):
option = options.get(option_name)
if option:
options[option_name] = {
tuple(
operation.new_name if f == operation.old_name else f
for f in fields
)
tuple(operation.new_name if f == operation.old_name else f for f in fields)
for fields in option
}
order_with_respect_to = options.get("order_with_respect_to")
order_with_respect_to = options.get('order_with_respect_to')
if order_with_respect_to == operation.old_name:
options["order_with_respect_to"] = operation.new_name
options['order_with_respect_to'] = operation.new_name
return [
CreateModel(
self.name,
@@ -295,9 +252,13 @@ class DeleteModel(ModelOperation):
def deconstruct(self):
kwargs = {
"name": self.name,
'name': self.name,
}
return (self.__class__.__qualname__, [], kwargs)
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.remove_model(app_label, self.name_lower)
@@ -322,7 +283,7 @@ class DeleteModel(ModelOperation):
@property
def migration_name_fragment(self):
return "delete_%s" % self.name_lower
return 'delete_%s' % self.name_lower
class RenameModel(ModelOperation):
@@ -343,13 +304,41 @@ class RenameModel(ModelOperation):
def deconstruct(self):
kwargs = {
"old_name": self.old_name,
"new_name": self.new_name,
'old_name': self.old_name,
'new_name': self.new_name,
}
return (self.__class__.__qualname__, [], kwargs)
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.rename_model(app_label, self.old_name, self.new_name)
# Add a new model.
renamed_model = state.models[app_label, self.old_name_lower].clone()
renamed_model.name = self.new_name
state.models[app_label, self.new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, self.old_name_lower)
new_remote_model = '%s.%s' % (app_label, self.new_name)
to_reload = set()
for model_state, name, field, reference in get_references(state, old_model_tuple):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
# Reload models related to old model before removing the old model.
state.reload_models(to_reload, delay=True)
# Remove the old model.
state.remove_model(app_label, self.old_name_lower)
state.reload_model(app_label, self.new_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.new_name)
@@ -372,24 +361,19 @@ class RenameModel(ModelOperation):
related_object.related_model._meta.app_label,
related_object.related_model._meta.model_name,
)
to_field = to_state.apps.get_model(*related_key)._meta.get_field(
related_object.field.name
)
to_field = to_state.apps.get_model(
*related_key
)._meta.get_field(related_object.field.name)
schema_editor.alter_field(
model,
related_object.field,
to_field,
)
# Rename M2M fields whose name is based on this model's name.
fields = zip(
old_model._meta.local_many_to_many, new_model._meta.local_many_to_many
)
fields = zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many)
for (old_field, new_field) in fields:
# Skip self-referential fields as these are renamed above.
if (
new_field.model == new_field.related_model
or not new_field.remote_field.through._meta.auto_created
):
if new_field.model == new_field.related_model or not new_field.remote_field.through._meta.auto_created:
continue
# Rename the M2M table that's based on this model's name.
old_m2m_model = old_field.remote_field.through
@@ -408,23 +392,18 @@ class RenameModel(ModelOperation):
)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
self.new_name_lower, self.old_name_lower = (
self.old_name_lower,
self.new_name_lower,
)
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
self.new_name, self.old_name = self.old_name, self.new_name
self.database_forwards(app_label, schema_editor, from_state, to_state)
self.new_name_lower, self.old_name_lower = (
self.old_name_lower,
self.new_name_lower,
)
self.new_name_lower, self.old_name_lower = self.old_name_lower, self.new_name_lower
self.new_name, self.old_name = self.old_name, self.new_name
def references_model(self, name, app_label):
return (
name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
name.lower() == self.old_name_lower or
name.lower() == self.new_name_lower
)
def describe(self):
@@ -432,13 +411,11 @@ class RenameModel(ModelOperation):
@property
def migration_name_fragment(self):
return "rename_%s_%s" % (self.old_name_lower, self.new_name_lower)
return 'rename_%s_%s' % (self.old_name_lower, self.new_name_lower)
def reduce(self, operation, app_label):
if (
isinstance(operation, RenameModel)
and self.new_name_lower == operation.old_name_lower
):
if (isinstance(operation, RenameModel) and
self.new_name_lower == operation.old_name_lower):
return [
RenameModel(
self.old_name,
@@ -447,17 +424,15 @@ class RenameModel(ModelOperation):
]
# Skip `ModelOperation.reduce` as we want to run `references_model`
# against self.new_name.
return super(ModelOperation, self).reduce(
operation, app_label
) or not operation.references_model(self.new_name, app_label)
return (
super(ModelOperation, self).reduce(operation, app_label) or
not operation.references_model(self.new_name, app_label)
)
class ModelOptionOperation(ModelOperation):
def reduce(self, operation, app_label):
if (
isinstance(operation, (self.__class__, DeleteModel))
and self.name_lower == operation.name_lower
):
if isinstance(operation, (self.__class__, DeleteModel)) and self.name_lower == operation.name_lower:
return [operation]
return super().reduce(operation, app_label)
@@ -471,13 +446,18 @@ class AlterModelTable(ModelOptionOperation):
def deconstruct(self):
kwargs = {
"name": self.name,
"table": self.table,
'name': self.name,
'table': self.table,
}
return (self.__class__.__qualname__, [], kwargs)
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(app_label, self.name_lower, {"db_table": self.table})
state.models[app_label, self.name_lower].options["db_table"] = self.table
state.reload_model(app_label, self.name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name)
@@ -489,9 +469,7 @@ class AlterModelTable(ModelOptionOperation):
new_model._meta.db_table,
)
# Rename M2M fields whose name is based on this model's db_table
for (old_field, new_field) in zip(
old_model._meta.local_many_to_many, new_model._meta.local_many_to_many
):
for (old_field, new_field) in zip(old_model._meta.local_many_to_many, new_model._meta.local_many_to_many):
if new_field.remote_field.through._meta.auto_created:
schema_editor.alter_db_table(
new_field.remote_field.through,
@@ -505,12 +483,12 @@ class AlterModelTable(ModelOptionOperation):
def describe(self):
return "Rename table for %s to %s" % (
self.name,
self.table if self.table is not None else "(default)",
self.table if self.table is not None else "(default)"
)
@property
def migration_name_fragment(self):
return "alter_%s_table" % self.name_lower
return 'alter_%s_table' % self.name_lower
class AlterTogetherOptionOperation(ModelOptionOperation):
@@ -528,23 +506,25 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
def deconstruct(self):
kwargs = {
"name": self.name,
'name': self.name,
self.option_name: self.option_value,
}
return (self.__class__.__qualname__, [], kwargs)
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(
app_label,
self.name_lower,
{self.option_name: self.option_value},
)
model_state = state.models[app_label, self.name_lower]
model_state.options[self.option_name] = self.option_value
state.reload_model(app_label, self.name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
new_model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, new_model):
old_model = from_state.apps.get_model(app_label, self.name)
alter_together = getattr(schema_editor, "alter_%s" % self.option_name)
alter_together = getattr(schema_editor, 'alter_%s' % self.option_name)
alter_together(
new_model,
getattr(old_model._meta, self.option_name, set()),
@@ -555,21 +535,20 @@ class AlterTogetherOptionOperation(ModelOptionOperation):
return self.database_forwards(app_label, schema_editor, from_state, to_state)
def references_field(self, model_name, name, app_label):
return self.references_model(model_name, app_label) and (
not self.option_value
or any((name in fields) for fields in self.option_value)
return (
self.references_model(model_name, app_label) and
(
not self.option_value or
any((name in fields) for fields in self.option_value)
)
)
def describe(self):
return "Alter %s for %s (%s constraint(s))" % (
self.option_name,
self.name,
len(self.option_value or ""),
)
return "Alter %s for %s (%s constraint(s))" % (self.option_name, self.name, len(self.option_value or ''))
@property
def migration_name_fragment(self):
return "alter_%s_%s" % (self.name_lower, self.option_name)
return 'alter_%s_%s' % (self.name_lower, self.option_name)
class AlterUniqueTogether(AlterTogetherOptionOperation):
@@ -577,8 +556,7 @@ class AlterUniqueTogether(AlterTogetherOptionOperation):
Change the value of unique_together to the target one.
Input value of unique_together must be a set of tuples.
"""
option_name = "unique_together"
option_name = 'unique_together'
def __init__(self, name, unique_together):
super().__init__(name, unique_together)
@@ -589,7 +567,6 @@ class AlterIndexTogether(AlterTogetherOptionOperation):
Change the value of index_together to the target one.
Input value of index_together must be a set of tuples.
"""
option_name = "index_together"
def __init__(self, name, index_together):
@@ -599,7 +576,7 @@ class AlterIndexTogether(AlterTogetherOptionOperation):
class AlterOrderWithRespectTo(ModelOptionOperation):
"""Represent a change with the order_with_respect_to option."""
option_name = "order_with_respect_to"
option_name = 'order_with_respect_to'
def __init__(self, name, order_with_respect_to):
self.order_with_respect_to = order_with_respect_to
@@ -607,36 +584,30 @@ class AlterOrderWithRespectTo(ModelOptionOperation):
def deconstruct(self):
kwargs = {
"name": self.name,
"order_with_respect_to": self.order_with_respect_to,
'name': self.name,
'order_with_respect_to': self.order_with_respect_to,
}
return (self.__class__.__qualname__, [], kwargs)
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(
app_label,
self.name_lower,
{self.option_name: self.order_with_respect_to},
)
model_state = state.models[app_label, self.name_lower]
model_state.options['order_with_respect_to'] = self.order_with_respect_to
state.reload_model(app_label, self.name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
to_model = to_state.apps.get_model(app_label, self.name)
if self.allow_migrate_model(schema_editor.connection.alias, to_model):
from_model = from_state.apps.get_model(app_label, self.name)
# Remove a field if we need to
if (
from_model._meta.order_with_respect_to
and not to_model._meta.order_with_respect_to
):
schema_editor.remove_field(
from_model, from_model._meta.get_field("_order")
)
if from_model._meta.order_with_respect_to and not to_model._meta.order_with_respect_to:
schema_editor.remove_field(from_model, from_model._meta.get_field("_order"))
# Add a field if we need to (altering the column is untouched as
# it's likely a rename)
elif (
to_model._meta.order_with_respect_to
and not from_model._meta.order_with_respect_to
):
elif to_model._meta.order_with_respect_to and not from_model._meta.order_with_respect_to:
field = to_model._meta.get_field("_order")
if not field.has_default():
field.default = 0
@@ -649,19 +620,20 @@ class AlterOrderWithRespectTo(ModelOptionOperation):
self.database_forwards(app_label, schema_editor, from_state, to_state)
def references_field(self, model_name, name, app_label):
return self.references_model(model_name, app_label) and (
self.order_with_respect_to is None or name == self.order_with_respect_to
return (
self.references_model(model_name, app_label) and
(
self.order_with_respect_to is None or
name == self.order_with_respect_to
)
)
def describe(self):
return "Set order_with_respect_to on %s to %s" % (
self.name,
self.order_with_respect_to,
)
return "Set order_with_respect_to on %s to %s" % (self.name, self.order_with_respect_to)
@property
def migration_name_fragment(self):
return "alter_%s_order_with_respect_to" % self.name_lower
return 'alter_%s_order_with_respect_to' % self.name_lower
class AlterModelOptions(ModelOptionOperation):
@@ -692,18 +664,22 @@ class AlterModelOptions(ModelOptionOperation):
def deconstruct(self):
kwargs = {
"name": self.name,
"options": self.options,
'name': self.name,
'options': self.options,
}
return (self.__class__.__qualname__, [], kwargs)
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
state.alter_model_options(
app_label,
self.name_lower,
self.options,
self.ALTER_OPTION_KEYS,
)
model_state = state.models[app_label, self.name_lower]
model_state.options = {**model_state.options, **self.options}
for key in self.ALTER_OPTION_KEYS:
if key not in self.options:
model_state.options.pop(key, False)
state.reload_model(app_label, self.name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
@@ -716,23 +692,29 @@ class AlterModelOptions(ModelOptionOperation):
@property
def migration_name_fragment(self):
return "alter_%s_options" % self.name_lower
return 'alter_%s_options' % self.name_lower
class AlterModelManagers(ModelOptionOperation):
"""Alter the model's managers."""
serialization_expand_args = ["managers"]
serialization_expand_args = ['managers']
def __init__(self, name, managers):
self.managers = managers
super().__init__(name)
def deconstruct(self):
return (self.__class__.__qualname__, [self.name, self.managers], {})
return (
self.__class__.__qualname__,
[self.name, self.managers],
{}
)
def state_forwards(self, app_label, state):
state.alter_model_managers(app_label, self.name_lower, self.managers)
model_state = state.models[app_label, self.name_lower]
model_state.managers = list(self.managers)
state.reload_model(app_label, self.name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
pass
@@ -745,11 +727,11 @@ class AlterModelManagers(ModelOptionOperation):
@property
def migration_name_fragment(self):
return "alter_%s_managers" % self.name_lower
return 'alter_%s_managers' % self.name_lower
class IndexOperation(Operation):
option_name = "indexes"
option_name = 'indexes'
@cached_property
def model_name_lower(self):
@@ -769,7 +751,9 @@ class AddIndex(IndexOperation):
self.index = index
def state_forwards(self, app_label, state):
state.add_index(app_label, self.model_name_lower, self.index)
model_state = state.models[app_label, self.model_name_lower]
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.index.clone()]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
@@ -783,8 +767,8 @@ class AddIndex(IndexOperation):
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"index": self.index,
'model_name': self.model_name,
'index': self.index,
}
return (
self.__class__.__qualname__,
@@ -794,20 +778,20 @@ class AddIndex(IndexOperation):
def describe(self):
if self.index.expressions:
return "Create index %s on %s on model %s" % (
return 'Create index %s on %s on model %s' % (
self.index.name,
", ".join([str(expression) for expression in self.index.expressions]),
', '.join([str(expression) for expression in self.index.expressions]),
self.model_name,
)
return "Create index %s on field(s) %s of model %s" % (
return 'Create index %s on field(s) %s of model %s' % (
self.index.name,
", ".join(self.index.fields),
', '.join(self.index.fields),
self.model_name,
)
@property
def migration_name_fragment(self):
return "%s_%s" % (self.model_name_lower, self.index.name.lower())
return '%s_%s' % (self.model_name_lower, self.index.name.lower())
class RemoveIndex(IndexOperation):
@@ -818,7 +802,10 @@ class RemoveIndex(IndexOperation):
self.name = name
def state_forwards(self, app_label, state):
state.remove_index(app_label, self.model_name_lower, self.name)
model_state = state.models[app_label, self.model_name_lower]
indexes = model_state.options[self.option_name]
model_state.options[self.option_name] = [idx for idx in indexes if idx.name != self.name]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = from_state.apps.get_model(app_label, self.model_name)
@@ -836,8 +823,8 @@ class RemoveIndex(IndexOperation):
def deconstruct(self):
kwargs = {
"model_name": self.model_name,
"name": self.name,
'model_name': self.model_name,
'name': self.name,
}
return (
self.__class__.__qualname__,
@@ -846,22 +833,24 @@ class RemoveIndex(IndexOperation):
)
def describe(self):
return "Remove index %s from %s" % (self.name, self.model_name)
return 'Remove index %s from %s' % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return "remove_%s_%s" % (self.model_name_lower, self.name.lower())
return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())
class AddConstraint(IndexOperation):
option_name = "constraints"
option_name = 'constraints'
def __init__(self, model_name, constraint):
self.model_name = model_name
self.constraint = constraint
def state_forwards(self, app_label, state):
state.add_constraint(app_label, self.model_name_lower, self.constraint)
model_state = state.models[app_label, self.model_name_lower]
model_state.options[self.option_name] = [*model_state.options[self.option_name], self.constraint]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
@@ -874,35 +863,31 @@ class AddConstraint(IndexOperation):
schema_editor.remove_constraint(model, self.constraint)
def deconstruct(self):
return (
self.__class__.__name__,
[],
{
"model_name": self.model_name,
"constraint": self.constraint,
},
)
return self.__class__.__name__, [], {
'model_name': self.model_name,
'constraint': self.constraint,
}
def describe(self):
return "Create constraint %s on model %s" % (
self.constraint.name,
self.model_name,
)
return 'Create constraint %s on model %s' % (self.constraint.name, self.model_name)
@property
def migration_name_fragment(self):
return "%s_%s" % (self.model_name_lower, self.constraint.name.lower())
return '%s_%s' % (self.model_name_lower, self.constraint.name.lower())
class RemoveConstraint(IndexOperation):
option_name = "constraints"
option_name = 'constraints'
def __init__(self, model_name, name):
self.model_name = model_name
self.name = name
def state_forwards(self, app_label, state):
state.remove_constraint(app_label, self.model_name_lower, self.name)
model_state = state.models[app_label, self.model_name_lower]
constraints = model_state.options[self.option_name]
model_state.options[self.option_name] = [c for c in constraints if c.name != self.name]
state.reload_model(app_label, self.model_name_lower, delay=True)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
model = to_state.apps.get_model(app_label, self.model_name)
@@ -919,18 +904,14 @@ class RemoveConstraint(IndexOperation):
schema_editor.add_constraint(model, constraint)
def deconstruct(self):
return (
self.__class__.__name__,
[],
{
"model_name": self.model_name,
"name": self.name,
},
)
return self.__class__.__name__, [], {
'model_name': self.model_name,
'name': self.name,
}
def describe(self):
return "Remove constraint %s from model %s" % (self.name, self.model_name)
return 'Remove constraint %s from model %s' % (self.name, self.model_name)
@property
def migration_name_fragment(self):
return "remove_%s_%s" % (self.model_name_lower, self.name.lower())
return 'remove_%s_%s' % (self.model_name_lower, self.name.lower())
@@ -11,7 +11,7 @@ class SeparateDatabaseAndState(Operation):
that affect the state or not the database, or so on.
"""
serialization_expand_args = ["database_operations", "state_operations"]
serialization_expand_args = ['database_operations', 'state_operations']
def __init__(self, database_operations=None, state_operations=None):
self.database_operations = database_operations or []
@@ -20,10 +20,14 @@ class SeparateDatabaseAndState(Operation):
def deconstruct(self):
kwargs = {}
if self.database_operations:
kwargs["database_operations"] = self.database_operations
kwargs['database_operations'] = self.database_operations
if self.state_operations:
kwargs["state_operations"] = self.state_operations
return (self.__class__.__qualname__, [], kwargs)
kwargs['state_operations'] = self.state_operations
return (
self.__class__.__qualname__,
[],
kwargs
)
def state_forwards(self, app_label, state):
for state_operation in self.state_operations:
@@ -34,9 +38,7 @@ class SeparateDatabaseAndState(Operation):
for database_operation in self.database_operations:
to_state = from_state.clone()
database_operation.state_forwards(app_label, to_state)
database_operation.database_forwards(
app_label, schema_editor, from_state, to_state
)
database_operation.database_forwards(app_label, schema_editor, from_state, to_state)
from_state = to_state
def database_backwards(self, app_label, schema_editor, from_state, to_state):
@@ -52,9 +54,7 @@ class SeparateDatabaseAndState(Operation):
for database_operation in reversed(self.database_operations):
from_state = to_state
to_state = to_states[database_operation]
database_operation.database_backwards(
app_label, schema_editor, from_state, to_state
)
database_operation.database_backwards(app_label, schema_editor, from_state, to_state)
def describe(self):
return "Custom state/database change combination"
@@ -67,12 +67,9 @@ class RunSQL(Operation):
Also accept a list of operations that represent the state change effected
by this SQL change, in case it's custom column/table creation/deletion.
"""
noop = ''
noop = ""
def __init__(
self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False
):
def __init__(self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False):
self.sql = sql
self.reverse_sql = reverse_sql
self.state_operations = state_operations or []
@@ -81,15 +78,19 @@ class RunSQL(Operation):
def deconstruct(self):
kwargs = {
"sql": self.sql,
'sql': self.sql,
}
if self.reverse_sql is not None:
kwargs["reverse_sql"] = self.reverse_sql
kwargs['reverse_sql'] = self.reverse_sql
if self.state_operations:
kwargs["state_operations"] = self.state_operations
kwargs['state_operations'] = self.state_operations
if self.hints:
kwargs["hints"] = self.hints
return (self.__class__.__qualname__, [], kwargs)
kwargs['hints'] = self.hints
return (
self.__class__.__qualname__,
[],
kwargs
)
@property
def reversible(self):
@@ -100,17 +101,13 @@ class RunSQL(Operation):
state_operation.state_forwards(app_label, state)
def database_forwards(self, app_label, schema_editor, from_state, to_state):
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
self._run_sql(schema_editor, self.sql)
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_sql is None:
raise NotImplementedError("You cannot reverse this operation")
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
self._run_sql(schema_editor, self.reverse_sql)
def describe(self):
@@ -140,9 +137,7 @@ class RunPython(Operation):
reduces_to_sql = False
def __init__(
self, code, reverse_code=None, atomic=None, hints=None, elidable=False
):
def __init__(self, code, reverse_code=None, atomic=None, hints=None, elidable=False):
self.atomic = atomic
# Forwards code
if not callable(code):
@@ -160,15 +155,19 @@ class RunPython(Operation):
def deconstruct(self):
kwargs = {
"code": self.code,
'code': self.code,
}
if self.reverse_code is not None:
kwargs["reverse_code"] = self.reverse_code
kwargs['reverse_code'] = self.reverse_code
if self.atomic is not None:
kwargs["atomic"] = self.atomic
kwargs['atomic'] = self.atomic
if self.hints:
kwargs["hints"] = self.hints
return (self.__class__.__qualname__, [], kwargs)
kwargs['hints'] = self.hints
return (
self.__class__.__qualname__,
[],
kwargs
)
@property
def reversible(self):
@@ -183,9 +182,7 @@ class RunPython(Operation):
# RunPython has access to all models. Ensure that all models are
# reloaded in case any are delayed.
from_state.clear_delayed_apps_cache()
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
# We now execute the Python code in a context that contains a 'models'
# object, representing the versioned models as an app registry.
# We could try to override the global cache, but then people will still
@@ -195,9 +192,7 @@ class RunPython(Operation):
def database_backwards(self, app_label, schema_editor, from_state, to_state):
if self.reverse_code is None:
raise NotImplementedError("You cannot reverse this operation")
if router.allow_migrate(
schema_editor.connection.alias, app_label, **self.hints
):
if router.allow_migrate(schema_editor.connection.alias, app_label, **self.hints):
self.reverse_code(from_state.apps, schema_editor)
def describe(self):
@@ -0,0 +1,102 @@
from collections import namedtuple
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
def resolve_relation(model, app_label=None, model_name=None):
"""
Turn a model class or model reference string and return a model tuple.
app_label and model_name are used to resolve the scope of recursive and
unscoped model relationship.
"""
if isinstance(model, str):
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
if app_label is None or model_name is None:
raise TypeError(
'app_label and model_name must be provided to resolve '
'recursive relationships.'
)
return app_label, model_name
if '.' in model:
app_label, model_name = model.split('.', 1)
return app_label, model_name.lower()
if app_label is None:
raise TypeError(
'app_label must be provided to resolve unscoped model '
'relationships.'
)
return app_label, model.lower()
return model._meta.app_label, model._meta.model_name
FieldReference = namedtuple('FieldReference', 'to through')
def field_references(
model_tuple,
field,
reference_model_tuple,
reference_field_name=None,
reference_field=None,
):
"""
Return either False or a FieldReference if `field` references provided
context.
False positives can be returned if `reference_field_name` is provided
without `reference_field` because of the introspection limitation it
incurs. This should not be an issue when this function is used to determine
whether or not an optimization can take place.
"""
remote_field = field.remote_field
if not remote_field:
return False
references_to = None
references_through = None
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
to_fields = getattr(field, 'to_fields', None)
if (
reference_field_name is None or
# Unspecified to_field(s).
to_fields is None or
# Reference to primary key.
(None in to_fields and (reference_field is None or reference_field.primary_key)) or
# Reference to field.
reference_field_name in to_fields
):
references_to = (remote_field, to_fields)
through = getattr(remote_field, 'through', None)
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
through_fields = remote_field.through_fields
if (
reference_field_name is None or
# Unspecified through_fields.
through_fields is None or
# Reference to field.
reference_field_name in through_fields
):
references_through = (remote_field, through_fields)
if not (references_to or references_through):
return False
return FieldReference(references_to, references_through)
def get_references(state, model_tuple, field_tuple=()):
"""
Generator of (model_state, name, field, reference) referencing
provided context.
If field_tuple is provided only references to this particular field of
model_tuple will be generated.
"""
for state_model_tuple, model_state in state.models.items():
for name, field in model_state.fields.items():
reference = field_references(state_model_tuple, field, model_tuple, *field_tuple)
if reference:
yield model_state, name, field, reference
def field_is_referenced(state, model_tuple, field_tuple):
"""Return whether `field_tuple` is referenced by any state models."""
return next(get_references(state, model_tuple, field_tuple), None) is not None
@@ -28,7 +28,7 @@ class MigrationOptimizer:
"""
# Internal tracking variable for test assertions about # of loops
if app_label is None:
raise TypeError("app_label must be a str.")
raise TypeError('app_label must be a str.')
self._iterations = 0
while True:
result = self.optimize_inner(operations, app_label)
@@ -43,10 +43,10 @@ class MigrationOptimizer:
for i, operation in enumerate(operations):
right = True # Should we reduce on the right or on the left.
# Compare it to each operation after it
for j, other in enumerate(operations[i + 1 :]):
for j, other in enumerate(operations[i + 1:]):
result = operation.reduce(other, app_label)
if isinstance(result, list):
in_between = operations[i + 1 : i + j + 1]
in_between = operations[i + 1:i + j + 1]
if right:
new_operations.extend(in_between)
new_operations.extend(result)
@@ -59,7 +59,7 @@ class MigrationOptimizer:
# Otherwise keep trying.
new_operations.append(operation)
break
new_operations.extend(operations[i + j + 2 :])
new_operations.extend(operations[i + j + 2:])
return new_operations
elif not result:
# Can't perform a right reduction.
@@ -33,7 +33,7 @@ class MigrationQuestioner:
# file check will ensure we skip South ones.
try:
app_config = apps.get_app_config(app_label)
except LookupError: # It's a fake app.
except LookupError: # It's a fake app.
return self.defaults.get("ask_initial", False)
migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
if migrations_import_path is None:
@@ -44,6 +44,7 @@ class MigrationQuestioner:
except ImportError:
return self.defaults.get("ask_initial", False)
else:
# getattr() needed on PY36 and older (replace with attribute access).
if getattr(migrations_module, "__file__", None):
filenames = os.listdir(os.path.dirname(migrations_module.__file__))
elif hasattr(migrations_module, "__path__"):
@@ -71,7 +72,7 @@ class MigrationQuestioner:
return self.defaults.get("ask_rename_model", False)
def ask_merge(self, app_label):
"""Should these migrations really be merged?"""
"""Do you really want to merge these migrations?"""
return self.defaults.get("ask_merge", False)
def ask_auto_now_add_addition(self, field_name, model_name):
@@ -81,6 +82,7 @@ class MigrationQuestioner:
class InteractiveMigrationQuestioner(MigrationQuestioner):
def _boolean_input(self, question, default=None):
result = input("%s " % question)
if not result and default is not None:
@@ -104,7 +106,7 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
return value
result = input("Please select a valid option: ")
def _ask_default(self, default=""):
def _ask_default(self, default=''):
"""
Prompt for a default value.
@@ -112,16 +114,13 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
string) which will be shown to the user and used as the return value
if the user doesn't provide any other input.
"""
print("Please enter the default value as valid Python.")
print("Please enter the default value now, as valid Python")
if default:
print(
f"Accept the default '{default}' by pressing 'Enter' or "
f"provide another value."
"You can accept the default '{}' by pressing 'Enter' or you "
"can provide another value.".format(default)
)
print(
"The datetime and django.utils.timezone modules are available, so "
"it is possible to provide e.g. timezone.now as a value."
)
print("The datetime and django.utils.timezone modules are available, so you can do e.g. timezone.now")
print("Type 'exit' to exit this prompt")
while True:
if default:
@@ -132,12 +131,12 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
if not code and default:
code = default
if not code:
print("Please enter some code, or 'exit' (without quotes) to exit.")
print("Please enter some code, or 'exit' (with no quotes) to exit.")
elif code == "exit":
sys.exit(1)
else:
try:
return eval(code, {}, {"datetime": datetime, "timezone": timezone})
return eval(code, {}, {'datetime': datetime, 'timezone': timezone})
except (SyntaxError, NameError) as e:
print("Invalid input: %s" % e)
@@ -145,18 +144,14 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
"""Adding a NOT NULL field to a model."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to add a non-nullable field '{field_name}' "
f"to {model_name} without specifying a default. This is "
f"because the database needs something to populate existing "
f"rows.\n"
f"Please select a fix:",
"You are trying to add a non-nullable field '%s' to %s without a default; "
"we can't do that (the database needs something to populate existing rows).\n"
"Please select a fix:" % (field_name, model_name),
[
(
"Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"
),
"Quit and manually define a default value in models.py.",
],
("Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"),
"Quit, and let me add a default in models.py",
]
)
if choice == 2:
sys.exit(3)
@@ -168,21 +163,18 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
"""Changing a NULL field to NOT NULL."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to change a nullable field '{field_name}' "
f"on {model_name} to non-nullable without providing a "
f"default. This is because the database needs something to "
f"populate existing rows.\n"
f"Please select a fix:",
"You are trying to change the nullable field '%s' on %s to non-nullable "
"without a default; we can't do that (the database needs something to "
"populate existing rows).\n"
"Please select a fix:" % (field_name, model_name),
[
(
"Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"
),
"Ignore for now. Existing rows that contain NULL values "
"will have to be handled manually, for example with a "
"RunPython or RunSQL operation.",
"Quit and manually define a default value in models.py.",
],
("Provide a one-off default now (will be set on all existing "
"rows with a null value for this column)"),
("Ignore for now, and let me handle existing rows with NULL myself "
"(e.g. because you added a RunPython or RunSQL operation to handle "
"NULL values in a previous data migration)"),
"Quit, and let me add a default in models.py",
]
)
if choice == 2:
return NOT_PROVIDED
@@ -194,33 +186,21 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
def ask_rename(self, model_name, old_name, new_name, field_instance):
"""Was this field really renamed?"""
msg = "Was %s.%s renamed to %s.%s (a %s)? [y/N]"
return self._boolean_input(
msg
% (
model_name,
old_name,
model_name,
new_name,
field_instance.__class__.__name__,
),
False,
)
msg = "Did you rename %s.%s to %s.%s (a %s)? [y/N]"
return self._boolean_input(msg % (model_name, old_name, model_name, new_name,
field_instance.__class__.__name__), False)
def ask_rename_model(self, old_model_state, new_model_state):
"""Was this model really renamed?"""
msg = "Was the model %s.%s renamed to %s? [y/N]"
return self._boolean_input(
msg
% (old_model_state.app_label, old_model_state.name, new_model_state.name),
False,
)
msg = "Did you rename the %s.%s model to %s? [y/N]"
return self._boolean_input(msg % (old_model_state.app_label, old_model_state.name,
new_model_state.name), False)
def ask_merge(self, app_label):
return self._boolean_input(
"\nMerging will only work if the operations printed above do not conflict\n"
+ "with each other (working on different fields or models)\n"
+ "Should these migration branches be merged? [y/N]",
"\nMerging will only work if the operations printed above do not conflict\n" +
"with each other (working on different fields or models)\n" +
"Do you want to merge these migration branches? [y/N]",
False,
)
@@ -228,24 +208,24 @@ class InteractiveMigrationQuestioner(MigrationQuestioner):
"""Adding an auto_now_add field to a model."""
if not self.dry_run:
choice = self._choice_input(
f"It is impossible to add the field '{field_name}' with "
f"'auto_now_add=True' to {model_name} without providing a "
f"default. This is because the database needs something to "
f"populate existing rows.\n",
"You are trying to add the field '{}' with 'auto_now_add=True' "
"to {} without a default; the database needs something to "
"populate existing rows.\n".format(field_name, model_name),
[
"Provide a one-off default now which will be set on all "
"existing rows",
"Quit and manually define a default value in models.py.",
],
"Provide a one-off default now (will be set on all "
"existing rows)",
"Quit, and let me add a default in models.py",
]
)
if choice == 2:
sys.exit(3)
else:
return self._ask_default(default="timezone.now")
return self._ask_default(default='timezone.now')
return None
class NonInteractiveMigrationQuestioner(MigrationQuestioner):
def ask_not_null_addition(self, field_name, model_name):
# We can't ask the user, so act like the user aborted.
sys.exit(3)
@@ -18,7 +18,6 @@ class MigrationRecorder:
If a migration is unapplied its row is removed from the table. Having
a row in the table always means a migration is applied.
"""
_migration_class = None
@classproperty
@@ -28,7 +27,6 @@ class MigrationRecorder:
MigrationRecorder.
"""
if cls._migration_class is None:
class Migration(models.Model):
app = models.CharField(max_length=255)
name = models.CharField(max_length=255)
@@ -36,11 +34,11 @@ class MigrationRecorder:
class Meta:
apps = Apps()
app_label = "migrations"
db_table = "django_migrations"
app_label = 'migrations'
db_table = 'django_migrations'
def __str__(self):
return "Migration %s for %s" % (self.name, self.app)
return 'Migration %s for %s' % (self.name, self.app)
cls._migration_class = Migration
return cls._migration_class
@@ -69,9 +67,7 @@ class MigrationRecorder:
with self.connection.schema_editor() as editor:
editor.create_model(self.Migration)
except DatabaseError as exc:
raise MigrationSchemaMissing(
"Unable to create the django_migrations table (%s)" % exc
)
raise MigrationSchemaMissing("Unable to create the django_migrations table (%s)" % exc)
def applied_migrations(self):
"""
@@ -79,10 +75,7 @@ class MigrationRecorder:
for all applied migrations.
"""
if self.has_table():
return {
(migration.app, migration.name): migration
for migration in self.migration_qs
}
return {(migration.app, migration.name): migration for migration in self.migration_qs}
else:
# If the django_migrations table doesn't exist, then no migrations
# are applied.
@@ -25,16 +25,12 @@ class BaseSerializer:
self.value = value
def serialize(self):
raise NotImplementedError(
"Subclasses of BaseSerializer must implement the serialize() method."
)
raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.')
class BaseSequenceSerializer(BaseSerializer):
def _format(self):
raise NotImplementedError(
"Subclasses of BaseSequenceSerializer must implement the _format() method."
)
raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.')
def serialize(self):
imports = set()
@@ -59,21 +55,19 @@ class ChoicesSerializer(BaseSerializer):
class DateTimeSerializer(BaseSerializer):
"""For datetime.*, except datetime.datetime."""
def serialize(self):
return repr(self.value), {"import datetime"}
return repr(self.value), {'import datetime'}
class DatetimeDatetimeSerializer(BaseSerializer):
"""For datetime.datetime."""
def serialize(self):
if self.value.tzinfo is not None and self.value.tzinfo != utc:
self.value = self.value.astimezone(utc)
imports = ["import datetime"]
if self.value.tzinfo is not None:
imports.append("from django.utils.timezone import utc")
return repr(self.value).replace("datetime.timezone.utc", "utc"), set(imports)
return repr(self.value).replace('<UTC>', 'utc'), set(imports)
class DecimalSerializer(BaseSerializer):
@@ -129,8 +123,8 @@ class EnumSerializer(BaseSerializer):
enum_class = self.value.__class__
module = enum_class.__module__
return (
"%s.%s[%r]" % (module, enum_class.__qualname__, self.value.name),
{"import %s" % module},
'%s.%s[%r]' % (module, enum_class.__qualname__, self.value.name),
{'import %s' % module},
)
@@ -148,29 +142,23 @@ class FrozensetSerializer(BaseSequenceSerializer):
class FunctionTypeSerializer(BaseSerializer):
def serialize(self):
if getattr(self.value, "__self__", None) and isinstance(
self.value.__self__, type
):
if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type):
klass = self.value.__self__
module = klass.__module__
return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {
"import %s" % module
}
return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module}
# Further error checking
if self.value.__name__ == "<lambda>":
if self.value.__name__ == '<lambda>':
raise ValueError("Cannot serialize function: lambda")
if self.value.__module__ is None:
raise ValueError("Cannot serialize function %r: No module" % self.value)
module_name = self.value.__module__
if "<" not in self.value.__qualname__: # Qualname can include <locals>
return "%s.%s" % (module_name, self.value.__qualname__), {
"import %s" % self.value.__module__
}
if '<' not in self.value.__qualname__: # Qualname can include <locals>
return '%s.%s' % (module_name, self.value.__qualname__), {'import %s' % self.value.__module__}
raise ValueError(
"Could not find function %s in %s.\n" % (self.value.__name__, module_name)
'Could not find function %s in %s.\n' % (self.value.__name__, module_name)
)
@@ -179,14 +167,11 @@ class FunctoolsPartialSerializer(BaseSerializer):
# Serialize functools.partial() arguments
func_string, func_imports = serializer_factory(self.value.func).serialize()
args_string, args_imports = serializer_factory(self.value.args).serialize()
keywords_string, keywords_imports = serializer_factory(
self.value.keywords
).serialize()
keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize()
# Add any imports needed by arguments
imports = {"import functools", *func_imports, *args_imports, *keywords_imports}
imports = {'import functools', *func_imports, *args_imports, *keywords_imports}
return (
"functools.%s(%s, *%s, **%s)"
% (
'functools.%s(%s, *%s, **%s)' % (
self.value.__class__.__name__,
func_string,
args_string,
@@ -229,10 +214,9 @@ class ModelManagerSerializer(DeconstructableSerializer):
class OperationSerializer(BaseSerializer):
def serialize(self):
from django.db.migrations.writer import OperationWriter
string, imports = OperationWriter(self.value, indentation=0).serialize()
# Nested operation, trailing comma is handled in upper OperationWriter._write()
return string.rstrip(","), imports
return string.rstrip(','), imports
class PathLikeSerializer(BaseSerializer):
@@ -244,24 +228,22 @@ class PathSerializer(BaseSerializer):
def serialize(self):
# Convert concrete paths to pure paths to avoid issues with migrations
# generated on one platform being used on a different platform.
prefix = "Pure" if isinstance(self.value, pathlib.Path) else ""
return "pathlib.%s%r" % (prefix, self.value), {"import pathlib"}
prefix = 'Pure' if isinstance(self.value, pathlib.Path) else ''
return 'pathlib.%s%r' % (prefix, self.value), {'import pathlib'}
class RegexSerializer(BaseSerializer):
def serialize(self):
regex_pattern, pattern_imports = serializer_factory(
self.value.pattern
).serialize()
regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize()
# Turn off default implicit flags (e.g. re.U) because regexes with the
# same implicit and explicit flags aren't equal.
flags = self.value.flags ^ re.compile("").flags
flags = self.value.flags ^ re.compile('').flags
regex_flags, flag_imports = serializer_factory(flags).serialize()
imports = {"import re", *pattern_imports, *flag_imports}
imports = {'import re', *pattern_imports, *flag_imports}
args = [regex_pattern]
if flags:
args.append(regex_flags)
return "re.compile(%s)" % ", ".join(args), imports
return "re.compile(%s)" % ', '.join(args), imports
class SequenceSerializer(BaseSequenceSerializer):
@@ -273,14 +255,12 @@ class SetSerializer(BaseSequenceSerializer):
def _format(self):
# Serialize as a set literal except when value is empty because {}
# is an empty dict.
return "{%s}" if self.value else "set(%s)"
return '{%s}' if self.value else 'set(%s)'
class SettingsReferenceSerializer(BaseSerializer):
def serialize(self):
return "settings.%s" % self.value.setting_name, {
"from django.conf import settings"
}
return "settings.%s" % self.value.setting_name, {"from django.conf import settings"}
class TupleSerializer(BaseSequenceSerializer):
@@ -293,8 +273,8 @@ class TupleSerializer(BaseSequenceSerializer):
class TypeSerializer(BaseSerializer):
def serialize(self):
special_cases = [
(models.Model, "models.Model", ["from django.db import models"]),
(type(None), "type(None)", []),
(models.Model, "models.Model", []),
(type(None), 'type(None)', []),
]
for case, string, imports in special_cases:
if case is self.value:
@@ -304,9 +284,7 @@ class TypeSerializer(BaseSerializer):
if module == builtins.__name__:
return self.value.__name__, set()
else:
return "%s.%s" % (module, self.value.__qualname__), {
"import %s" % module
}
return "%s.%s" % (module, self.value.__qualname__), {"import %s" % module}
class UUIDSerializer(BaseSerializer):
@@ -331,11 +309,7 @@ class Serializer:
(bool, int, type(None), bytes, str, range): BaseSimpleSerializer,
decimal.Decimal: DecimalSerializer,
(functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
(
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
): FunctionTypeSerializer,
(types.FunctionType, types.BuiltinFunctionType, types.MethodType): FunctionTypeSerializer,
collections.abc.Iterable: IterableSerializer,
(COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
uuid.UUID: UUIDSerializer,
@@ -346,9 +320,7 @@ class Serializer:
@classmethod
def register(cls, type_, serializer):
if not issubclass(serializer, BaseSerializer):
raise ValueError(
"'%s' must inherit from 'BaseSerializer'." % serializer.__name__
)
raise ValueError("'%s' must inherit from 'BaseSerializer'." % serializer.__name__)
cls._registry[type_] = serializer
@classmethod
@@ -373,7 +345,7 @@ def serializer_factory(value):
if isinstance(value, type):
return TypeSerializer(value)
# Anything that knows how to deconstruct itself.
if hasattr(value, "deconstruct"):
if hasattr(value, 'deconstruct'):
return DeconstructableSerializer(value)
for type_, serializer_cls in Serializer._registry.items():
if isinstance(value, type_):
@@ -1,16 +1,10 @@
import copy
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from django.apps import AppConfig
from django.apps.registry import Apps
from django.apps.registry import apps as global_apps
from django.apps.registry import Apps, apps as global_apps
from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.migrations.utils import field_is_referenced, get_references
from django.db.models import NOT_PROVIDED
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
from django.db.models.options import DEFAULT_NAMES, normalize_together
from django.db.models.utils import make_model_tuple
@@ -19,12 +13,11 @@ from django.utils.module_loading import import_string
from django.utils.version import get_docs_version
from .exceptions import InvalidBasesError
from .utils import resolve_relation
def _get_app_label_and_model_name(model, app_label=""):
def _get_app_label_and_model_name(model, app_label=''):
if isinstance(model, str):
split = model.split(".", 1)
split = model.split('.', 1)
return tuple(split) if len(split) == 2 else (app_label, split[0])
else:
return model._meta.app_label, model._meta.model_name
@@ -33,17 +26,12 @@ def _get_app_label_and_model_name(model, app_label=""):
def _get_related_models(m):
"""Return all models that have a direct relationship to the given model."""
related_models = [
subclass
for subclass in m.__subclasses__()
subclass for subclass in m.__subclasses__()
if issubclass(subclass, models.Model)
]
related_fields_models = set()
for f in m._meta.get_fields(include_parents=True, include_hidden=True):
if (
f.is_relation
and f.related_model is not None
and not isinstance(f.related_model, str)
):
if f.is_relation and f.related_model is not None and not isinstance(f.related_model, str):
related_fields_models.add(f.model)
related_models.append(f.related_model)
# Reverse accessors of foreign keys to proxy models are attached to their
@@ -79,10 +67,7 @@ def get_related_models_recursive(model):
seen = set()
queue = _get_related_models(model)
for rel_mod in queue:
rel_app_label, rel_model_name = (
rel_mod._meta.app_label,
rel_mod._meta.model_name,
)
rel_app_label, rel_model_name = rel_mod._meta.app_label, rel_mod._meta.model_name
if (rel_app_label, rel_model_name) in seen:
continue
seen.add((rel_app_label, rel_model_name))
@@ -100,228 +85,23 @@ class ProjectState:
def __init__(self, models=None, real_apps=None):
self.models = models or {}
# Apps to include from main registry, usually unmigrated ones
if real_apps is None:
real_apps = set()
else:
assert isinstance(real_apps, set)
self.real_apps = real_apps
self.real_apps = real_apps or []
self.is_delayed = False
# {remote_model_key: {model_key: {field_name: field}}}
self._relations = None
@property
def relations(self):
if self._relations is None:
self.resolve_fields_and_relations()
return self._relations
def add_model(self, model_state):
model_key = model_state.app_label, model_state.name_lower
self.models[model_key] = model_state
if self._relations is not None:
self.resolve_model_relations(model_key)
if "apps" in self.__dict__: # hasattr would cache the property
self.reload_model(*model_key)
app_label, model_name = model_state.app_label, model_state.name_lower
self.models[(app_label, model_name)] = model_state
if 'apps' in self.__dict__: # hasattr would cache the property
self.reload_model(app_label, model_name)
def remove_model(self, app_label, model_name):
model_key = app_label, model_name
del self.models[model_key]
if self._relations is not None:
self._relations.pop(model_key, None)
# Call list() since _relations can change size during iteration.
for related_model_key, model_relations in list(self._relations.items()):
model_relations.pop(model_key, None)
if not model_relations:
del self._relations[related_model_key]
if "apps" in self.__dict__: # hasattr would cache the property
self.apps.unregister_model(*model_key)
del self.models[app_label, model_name]
if 'apps' in self.__dict__: # hasattr would cache the property
self.apps.unregister_model(app_label, model_name)
# Need to do this explicitly since unregister_model() doesn't clear
# the cache automatically (#24513)
self.apps.clear_cache()
def rename_model(self, app_label, old_name, new_name):
# Add a new model.
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
renamed_model = self.models[app_label, old_name_lower].clone()
renamed_model.name = new_name
self.models[app_label, new_name_lower] = renamed_model
# Repoint all fields pointing to the old model to the new one.
old_model_tuple = (app_label, old_name_lower)
new_remote_model = f"{app_label}.{new_name}"
to_reload = set()
for model_state, name, field, reference in get_references(
self, old_model_tuple
):
changed_field = None
if reference.to:
changed_field = field.clone()
changed_field.remote_field.model = new_remote_model
if reference.through:
if changed_field is None:
changed_field = field.clone()
changed_field.remote_field.through = new_remote_model
if changed_field:
model_state.fields[name] = changed_field
to_reload.add((model_state.app_label, model_state.name_lower))
if self._relations is not None:
old_name_key = app_label, old_name_lower
new_name_key = app_label, new_name_lower
if old_name_key in self._relations:
self._relations[new_name_key] = self._relations.pop(old_name_key)
for model_relations in self._relations.values():
if old_name_key in model_relations:
model_relations[new_name_key] = model_relations.pop(old_name_key)
# Reload models related to old model before removing the old model.
self.reload_models(to_reload, delay=True)
# Remove the old model.
self.remove_model(app_label, old_name_lower)
self.reload_model(app_label, new_name_lower, delay=True)
def alter_model_options(self, app_label, model_name, options, option_keys=None):
model_state = self.models[app_label, model_name]
model_state.options = {**model_state.options, **options}
if option_keys:
for key in option_keys:
if key not in options:
model_state.options.pop(key, False)
self.reload_model(app_label, model_name, delay=True)
def alter_model_managers(self, app_label, model_name, managers):
model_state = self.models[app_label, model_name]
model_state.managers = list(managers)
self.reload_model(app_label, model_name, delay=True)
def _append_option(self, app_label, model_name, option_name, obj):
model_state = self.models[app_label, model_name]
model_state.options[option_name] = [*model_state.options[option_name], obj]
self.reload_model(app_label, model_name, delay=True)
def _remove_option(self, app_label, model_name, option_name, obj_name):
model_state = self.models[app_label, model_name]
objs = model_state.options[option_name]
model_state.options[option_name] = [obj for obj in objs if obj.name != obj_name]
self.reload_model(app_label, model_name, delay=True)
def add_index(self, app_label, model_name, index):
self._append_option(app_label, model_name, "indexes", index)
def remove_index(self, app_label, model_name, index_name):
self._remove_option(app_label, model_name, "indexes", index_name)
def add_constraint(self, app_label, model_name, constraint):
self._append_option(app_label, model_name, "constraints", constraint)
def remove_constraint(self, app_label, model_name, constraint_name):
self._remove_option(app_label, model_name, "constraints", constraint_name)
def add_field(self, app_label, model_name, name, field, preserve_default):
# If preserve default is off, don't use the default for future state.
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
model_key = app_label, model_name
self.models[model_key].fields[name] = field
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, field)
# Delay rendering of relationships if it's not a relational field.
delay = not field.is_relation
self.reload_model(*model_key, delay=delay)
def remove_field(self, app_label, model_name, name):
model_key = app_label, model_name
model_state = self.models[model_key]
old_field = model_state.fields.pop(name)
if self._relations is not None:
self.resolve_model_field_relations(model_key, name, old_field)
# Delay rendering of relationships if it's not a relational field.
delay = not old_field.is_relation
self.reload_model(*model_key, delay=delay)
def alter_field(self, app_label, model_name, name, field, preserve_default):
if not preserve_default:
field = field.clone()
field.default = NOT_PROVIDED
else:
field = field
model_key = app_label, model_name
fields = self.models[model_key].fields
if self._relations is not None:
old_field = fields.pop(name)
if old_field.is_relation:
self.resolve_model_field_relations(model_key, name, old_field)
fields[name] = field
if field.is_relation:
self.resolve_model_field_relations(model_key, name, field)
else:
fields[name] = field
# TODO: investigate if old relational fields must be reloaded or if
# it's sufficient if the new field is (#27737).
# Delay rendering of relationships if it's not a relational field and
# not referenced by a foreign key.
delay = not field.is_relation and not field_is_referenced(
self, model_key, (name, field)
)
self.reload_model(*model_key, delay=delay)
def rename_field(self, app_label, model_name, old_name, new_name):
model_key = app_label, model_name
model_state = self.models[model_key]
# Rename the field.
fields = model_state.fields
try:
found = fields.pop(old_name)
except KeyError:
raise FieldDoesNotExist(
f"{app_label}.{model_name} has no field named '{old_name}'"
)
fields[new_name] = found
for field in fields.values():
# Fix from_fields to refer to the new field.
from_fields = getattr(field, "from_fields", None)
if from_fields:
field.from_fields = tuple(
[
new_name if from_field_name == old_name else from_field_name
for from_field_name in from_fields
]
)
# Fix index/unique_together to refer to the new field.
options = model_state.options
for option in ("index_together", "unique_together"):
if option in options:
options[option] = [
[new_name if n == old_name else n for n in together]
for together in options[option]
]
# Fix to_fields to refer to the new field.
delay = True
references = get_references(self, model_key, (old_name, found))
for *_, field, reference in references:
delay = False
if reference.to:
remote_field, to_fields = reference.to
if getattr(remote_field, "field_name", None) == old_name:
remote_field.field_name = new_name
if to_fields:
field.to_fields = tuple(
[
new_name if to_field_name == old_name else to_field_name
for to_field_name in to_fields
]
)
if self._relations is not None:
old_name_lower = old_name.lower()
new_name_lower = new_name.lower()
for to_model in self._relations.values():
if old_name_lower in to_model[model_key]:
field = to_model[model_key].pop(old_name_lower)
field.name = new_name_lower
to_model[model_key][new_name_lower] = field
self.reload_model(*model_key, delay=delay)
def _find_reload_model(self, app_label, model_name, delay=False):
if delay:
self.is_delayed = True
@@ -349,9 +129,7 @@ class ProjectState:
if field.is_relation:
if field.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT:
continue
rel_app_label, rel_model_name = _get_app_label_and_model_name(
field.related_model, app_label
)
rel_app_label, rel_model_name = _get_app_label_and_model_name(field.related_model, app_label)
direct_related_models.add((rel_app_label, rel_model_name.lower()))
# For all direct related models recursively get all related models.
@@ -373,17 +151,15 @@ class ProjectState:
return related_models
def reload_model(self, app_label, model_name, delay=False):
if "apps" in self.__dict__: # hasattr would cache the property
if 'apps' in self.__dict__: # hasattr would cache the property
related_models = self._find_reload_model(app_label, model_name, delay)
self._reload(related_models)
def reload_models(self, models, delay=True):
if "apps" in self.__dict__: # hasattr would cache the property
if 'apps' in self.__dict__: # hasattr would cache the property
related_models = set()
for app_label, model_name in models:
related_models.update(
self._find_reload_model(app_label, model_name, delay)
)
related_models.update(self._find_reload_model(app_label, model_name, delay))
self._reload(related_models)
def _reload(self, related_models):
@@ -412,137 +188,30 @@ class ProjectState:
# Render all models
self.apps.render_multiple(states_to_be_rendered)
def update_model_field_relation(
self,
model,
model_key,
field_name,
field,
concretes,
):
remote_model_key = resolve_relation(model, *model_key)
if remote_model_key[0] not in self.real_apps and remote_model_key in concretes:
remote_model_key = concretes[remote_model_key]
relations_to_remote_model = self._relations[remote_model_key]
if field_name in self.models[model_key].fields:
# The assert holds because it's a new relation, or an altered
# relation, in which case references have been removed by
# alter_field().
assert field_name not in relations_to_remote_model[model_key]
relations_to_remote_model[model_key][field_name] = field
else:
del relations_to_remote_model[model_key][field_name]
if not relations_to_remote_model[model_key]:
del relations_to_remote_model[model_key]
def resolve_model_field_relations(
self,
model_key,
field_name,
field,
concretes=None,
):
remote_field = field.remote_field
if not remote_field:
return
if concretes is None:
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
self.update_model_field_relation(
remote_field.model,
model_key,
field_name,
field,
concretes,
)
through = getattr(remote_field, "through", None)
if not through:
return
self.update_model_field_relation(
through, model_key, field_name, field, concretes
)
def resolve_model_relations(self, model_key, concretes=None):
if concretes is None:
concretes, _ = self._get_concrete_models_mapping_and_proxy_models()
model_state = self.models[model_key]
for field_name, field in model_state.fields.items():
self.resolve_model_field_relations(model_key, field_name, field, concretes)
def resolve_fields_and_relations(self):
# Resolve fields.
for model_state in self.models.values():
for field_name, field in model_state.fields.items():
field.name = field_name
# Resolve relations.
# {remote_model_key: {model_key: {field_name: field}}}
self._relations = defaultdict(partial(defaultdict, dict))
concretes, proxies = self._get_concrete_models_mapping_and_proxy_models()
for model_key in concretes:
self.resolve_model_relations(model_key, concretes)
for model_key in proxies:
self._relations[model_key] = self._relations[concretes[model_key]]
def get_concrete_model_key(self, model):
(
concrete_models_mapping,
_,
) = self._get_concrete_models_mapping_and_proxy_models()
model_key = make_model_tuple(model)
return concrete_models_mapping[model_key]
def _get_concrete_models_mapping_and_proxy_models(self):
concrete_models_mapping = {}
proxy_models = {}
# Split models to proxy and concrete models.
for model_key, model_state in self.models.items():
if model_state.options.get("proxy"):
proxy_models[model_key] = model_state
# Find a concrete model for the proxy.
concrete_models_mapping[
model_key
] = self._find_concrete_model_from_proxy(
proxy_models,
model_state,
)
else:
concrete_models_mapping[model_key] = model_key
return concrete_models_mapping, proxy_models
def _find_concrete_model_from_proxy(self, proxy_models, model_state):
for base in model_state.bases:
if not (isinstance(base, str) or issubclass(base, models.Model)):
continue
base_key = make_model_tuple(base)
base_state = proxy_models.get(base_key)
if not base_state:
# Concrete model found, stop looking at bases.
return base_key
return self._find_concrete_model_from_proxy(proxy_models, base_state)
def clone(self):
"""Return an exact copy of this ProjectState."""
new_state = ProjectState(
models={k: v.clone() for k, v in self.models.items()},
real_apps=self.real_apps,
)
if "apps" in self.__dict__:
if 'apps' in self.__dict__:
new_state.apps = self.apps.clone()
new_state.is_delayed = self.is_delayed
return new_state
def clear_delayed_apps_cache(self):
if self.is_delayed and "apps" in self.__dict__:
del self.__dict__["apps"]
if self.is_delayed and 'apps' in self.__dict__:
del self.__dict__['apps']
@cached_property
def apps(self):
return StateApps(self.real_apps, self.models)
@property
def concrete_apps(self):
self.apps = StateApps(self.real_apps, self.models, ignore_swappable=True)
return self.apps
@classmethod
def from_apps(cls, apps):
"""Take an Apps and return a ProjectState matching it."""
@@ -553,12 +222,11 @@ class ProjectState:
return cls(app_models)
def __eq__(self, other):
return self.models == other.models and self.real_apps == other.real_apps
return self.models == other.models and set(self.real_apps) == set(other.real_apps)
class AppConfigStub(AppConfig):
"""Stub of an AppConfig. Only provides a label and a dict of models."""
def __init__(self, label):
self.apps = None
self.models = {}
@@ -577,7 +245,6 @@ class StateApps(Apps):
Subclass of the global Apps registry class to better handle dynamic model
additions and removals.
"""
def __init__(self, real_apps, models, ignore_swappable=False):
# Any apps in self.real_apps should have all their models included
# in the render. We don't use the original model instances as there
@@ -591,9 +258,7 @@ class StateApps(Apps):
self.real_models.append(ModelState.from_model(model, exclude_rels=True))
# Populate the app registry with a stub for each application.
app_labels = {model_state.app_label for model_state in models.values()}
app_configs = [
AppConfigStub(label) for label in sorted([*real_apps, *app_labels])
]
app_configs = [AppConfigStub(label) for label in sorted([*real_apps, *app_labels])]
super().__init__(app_configs)
# These locks get in the way of copying as implemented in clone(),
@@ -606,10 +271,7 @@ class StateApps(Apps):
# There shouldn't be any operations pending at this point.
from django.core.checks.model_checks import _check_lazy_references
ignore = (
{make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
)
ignore = {make_model_tuple(settings.AUTH_USER_MODEL)} if ignore_swappable else set()
errors = _check_lazy_references(self, ignore=ignore)
if errors:
raise ValueError("\n".join(error.msg for error in errors))
@@ -645,12 +307,10 @@ class StateApps(Apps):
new_unrendered_models.append(model)
if len(new_unrendered_models) == len(unrendered_models):
raise InvalidBasesError(
"Cannot resolve bases for %r\nThis can happen if you are "
"inheriting models from an app with migrations (e.g. "
"contrib.auth)\n in an app with no migrations; see "
"https://docs.djangoproject.com/en/%s/topics/migrations/"
"#dependencies for more"
% (new_unrendered_models, get_docs_version())
"Cannot resolve bases for %r\nThis can happen if you are inheriting models from an "
"app with migrations (e.g. contrib.auth)\n in an app with no migrations; see "
"https://docs.djangoproject.com/en/%s/topics/migrations/#dependencies "
"for more" % (new_unrendered_models, get_docs_version())
)
unrendered_models = new_unrendered_models
@@ -694,36 +354,34 @@ class ModelState:
assign new ones, as these are not detached during a clone.
"""
def __init__(
self, app_label, name, fields, options=None, bases=None, managers=None
):
def __init__(self, app_label, name, fields, options=None, bases=None, managers=None):
self.app_label = app_label
self.name = name
self.fields = dict(fields)
self.options = options or {}
self.options.setdefault("indexes", [])
self.options.setdefault("constraints", [])
self.options.setdefault('indexes', [])
self.options.setdefault('constraints', [])
self.bases = bases or (models.Model,)
self.managers = managers or []
for name, field in self.fields.items():
# Sanity-check that fields are NOT already bound to a model.
if hasattr(field, "model"):
if hasattr(field, 'model'):
raise ValueError(
'ModelState.fields cannot be bound to a model - "%s" is.' % name
)
# Sanity-check that relation fields are NOT referring to a model class.
if field.is_relation and hasattr(field.related_model, "_meta"):
if field.is_relation and hasattr(field.related_model, '_meta'):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.to" does. '
"Use a string reference instead." % name
'Use a string reference instead.' % name
)
if field.many_to_many and hasattr(field.remote_field.through, "_meta"):
if field.many_to_many and hasattr(field.remote_field.through, '_meta'):
raise ValueError(
'ModelState.fields cannot refer to a model class - "%s.through" '
"does. Use a string reference instead." % name
'ModelState.fields cannot refer to a model class - "%s.through" does. '
'Use a string reference instead.' % name
)
# Sanity-check that indexes have their name set.
for index in self.options["indexes"]:
for index in self.options['indexes']:
if not index.name:
raise ValueError(
"Indexes passed to ModelState require a name attribute. "
@@ -734,11 +392,6 @@ class ModelState:
def name_lower(self):
return self.name.lower()
def get_field(self, field_name):
if field_name == "_order":
field_name = self.options.get("order_with_respect_to", field_name)
return self.fields[field_name]
@classmethod
def from_model(cls, model, exclude_rels=False):
"""Given a model, return a ModelState representing it."""
@@ -753,28 +406,22 @@ class ModelState:
try:
fields.append((name, field.clone()))
except TypeError as e:
raise TypeError(
"Couldn't reconstruct field %s on %s: %s"
% (
name,
model._meta.label,
e,
)
)
raise TypeError("Couldn't reconstruct field %s on %s: %s" % (
name,
model._meta.label,
e,
))
if not exclude_rels:
for field in model._meta.local_many_to_many:
name = field.name
try:
fields.append((name, field.clone()))
except TypeError as e:
raise TypeError(
"Couldn't reconstruct m2m field %s on %s: %s"
% (
name,
model._meta.object_name,
e,
)
)
raise TypeError("Couldn't reconstruct m2m field %s on %s: %s" % (
name,
model._meta.object_name,
e,
))
# Extract the options
options = {}
for name in DEFAULT_NAMES:
@@ -793,11 +440,9 @@ class ModelState:
for index in indexes:
if not index.name:
index.set_name_with_model(model)
options["indexes"] = indexes
elif name == "constraints":
options["constraints"] = [
con.clone() for con in model._meta.constraints
]
options['indexes'] = indexes
elif name == 'constraints':
options['constraints'] = [con.clone() for con in model._meta.constraints]
else:
options[name] = model._meta.original_attrs[name]
# If we're ignoring relationships, remove all field-listing model
@@ -807,10 +452,8 @@ class ModelState:
if key in options:
del options[key]
# Private fields are ignored, so remove options that refer to them.
elif options.get("order_with_respect_to") in {
field.name for field in model._meta.private_fields
}:
del options["order_with_respect_to"]
elif options.get('order_with_respect_to') in {field.name for field in model._meta.private_fields}:
del options['order_with_respect_to']
def flatten_bases(model):
bases = []
@@ -826,19 +469,19 @@ class ModelState:
# __bases__ we may end up with duplicates and ordering issues, we
# therefore discard any duplicates and reorder the bases according
# to their index in the MRO.
flattened_bases = sorted(
set(flatten_bases(model)), key=lambda x: model.__mro__.index(x)
)
flattened_bases = sorted(set(flatten_bases(model)), key=lambda x: model.__mro__.index(x))
# Make our record
bases = tuple(
(base._meta.label_lower if hasattr(base, "_meta") else base)
(
base._meta.label_lower
if hasattr(base, "_meta") else
base
)
for base in flattened_bases
)
# Ensure at least one base inherits from models.Model
if not any(
(isinstance(base, str) or issubclass(base, models.Model)) for base in bases
):
if not any((isinstance(base, str) or issubclass(base, models.Model)) for base in bases):
bases = (models.Model,)
managers = []
@@ -865,7 +508,7 @@ class ModelState:
managers.append((manager.name, new_manager))
# Ignore a shimmed default manager called objects if it's the only one.
if managers == [("objects", default_manager_shim)]:
if managers == [('objects', default_manager_shim)]:
managers = []
# Construct the new ModelState
@@ -908,7 +551,7 @@ class ModelState:
def render(self, apps):
"""Create a Model object from our current state into the given apps."""
# First, make a Meta object
meta_contents = {"app_label": self.app_label, "apps": apps, **self.options}
meta_contents = {'app_label': self.app_label, 'apps': apps, **self.options}
meta = type("Meta", (), meta_contents)
# Then, work out our bases
try:
@@ -917,13 +560,11 @@ class ModelState:
for base in self.bases
)
except LookupError:
raise InvalidBasesError(
"Cannot resolve one or more bases from %r" % (self.bases,)
)
raise InvalidBasesError("Cannot resolve one or more bases from %r" % (self.bases,))
# Clone fields for the body, add other bits.
body = {name: field.clone() for name, field in self.fields.items()}
body["Meta"] = meta
body["__module__"] = "__fake__"
body['Meta'] = meta
body['__module__'] = "__fake__"
# Restore managers
body.update(self.construct_managers())
@@ -931,33 +572,33 @@ class ModelState:
return type(self.name, bases, body)
def get_index_by_name(self, name):
for index in self.options["indexes"]:
for index in self.options['indexes']:
if index.name == name:
return index
raise ValueError("No index named %s on model %s" % (name, self.name))
def get_constraint_by_name(self, name):
for constraint in self.options["constraints"]:
for constraint in self.options['constraints']:
if constraint.name == name:
return constraint
raise ValueError("No constraint named %s on model %s" % (name, self.name))
raise ValueError('No constraint named %s on model %s' % (name, self.name))
def __repr__(self):
return "<%s: '%s.%s'>" % (self.__class__.__name__, self.app_label, self.name)
def __eq__(self, other):
return (
(self.app_label == other.app_label)
and (self.name == other.name)
and (len(self.fields) == len(other.fields))
and all(
(self.app_label == other.app_label) and
(self.name == other.name) and
(len(self.fields) == len(other.fields)) and
all(
k1 == k2 and f1.deconstruct()[1:] == f2.deconstruct()[1:]
for (k1, f1), (k2, f2) in zip(
sorted(self.fields.items()),
sorted(other.fields.items()),
)
)
and (self.options == other.options)
and (self.bases == other.bases)
and (self.managers == other.managers)
) and
(self.options == other.options) and
(self.bases == other.bases) and
(self.managers == other.managers)
)
@@ -1,12 +1,7 @@
import datetime
import re
from collections import namedtuple
from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT
FieldReference = namedtuple("FieldReference", "to through")
COMPILED_REGEX_TYPE = type(re.compile(""))
COMPILED_REGEX_TYPE = type(re.compile(''))
class RegexObject:
@@ -20,108 +15,3 @@ class RegexObject:
def get_migration_name_timestamp():
return datetime.datetime.now().strftime("%Y%m%d_%H%M")
def resolve_relation(model, app_label=None, model_name=None):
"""
Turn a model class or model reference string and return a model tuple.
app_label and model_name are used to resolve the scope of recursive and
unscoped model relationship.
"""
if isinstance(model, str):
if model == RECURSIVE_RELATIONSHIP_CONSTANT:
if app_label is None or model_name is None:
raise TypeError(
"app_label and model_name must be provided to resolve "
"recursive relationships."
)
return app_label, model_name
if "." in model:
app_label, model_name = model.split(".", 1)
return app_label, model_name.lower()
if app_label is None:
raise TypeError(
"app_label must be provided to resolve unscoped model relationships."
)
return app_label, model.lower()
return model._meta.app_label, model._meta.model_name
def field_references(
model_tuple,
field,
reference_model_tuple,
reference_field_name=None,
reference_field=None,
):
"""
Return either False or a FieldReference if `field` references provided
context.
False positives can be returned if `reference_field_name` is provided
without `reference_field` because of the introspection limitation it
incurs. This should not be an issue when this function is used to determine
whether or not an optimization can take place.
"""
remote_field = field.remote_field
if not remote_field:
return False
references_to = None
references_through = None
if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
to_fields = getattr(field, "to_fields", None)
if (
reference_field_name is None
or
# Unspecified to_field(s).
to_fields is None
or
# Reference to primary key.
(
None in to_fields
and (reference_field is None or reference_field.primary_key)
)
or
# Reference to field.
reference_field_name in to_fields
):
references_to = (remote_field, to_fields)
through = getattr(remote_field, "through", None)
if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
through_fields = remote_field.through_fields
if (
reference_field_name is None
or
# Unspecified through_fields.
through_fields is None
or
# Reference to field.
reference_field_name in through_fields
):
references_through = (remote_field, through_fields)
if not (references_to or references_through):
return False
return FieldReference(references_to, references_through)
def get_references(state, model_tuple, field_tuple=()):
"""
Generator of (model_state, name, field, reference) referencing
provided context.
If field_tuple is provided only references to this particular field of
model_tuple will be generated.
"""
for state_model_tuple, model_state in state.models.items():
for name, field in model_state.fields.items():
reference = field_references(
state_model_tuple, field, model_tuple, *field_tuple
)
if reference:
yield model_state, name, field, reference
def field_is_referenced(state, model_tuple, field_tuple):
"""Return whether `field_tuple` is referenced by any state models."""
return next(get_references(state, model_tuple, field_tuple), None) is not None
@@ -1,10 +1,10 @@
import os
import re
from importlib import import_module
from django import get_version
from django.apps import apps
# SettingsReference imported for backwards compatibility in Django 2.2.
from django.conf import SettingsReference # NOQA
from django.db import migrations
@@ -22,30 +22,30 @@ class OperationWriter:
self.indentation = indentation
def serialize(self):
def _write(_arg_name, _arg_value):
if _arg_name in self.operation.serialization_expand_args and isinstance(
_arg_value, (list, tuple, dict)
):
if (_arg_name in self.operation.serialization_expand_args and
isinstance(_arg_value, (list, tuple, dict))):
if isinstance(_arg_value, dict):
self.feed("%s={" % _arg_name)
self.feed('%s={' % _arg_name)
self.indent()
for key, value in _arg_value.items():
key_string, key_imports = MigrationWriter.serialize(key)
arg_string, arg_imports = MigrationWriter.serialize(value)
args = arg_string.splitlines()
if len(args) > 1:
self.feed("%s: %s" % (key_string, args[0]))
self.feed('%s: %s' % (key_string, args[0]))
for arg in args[1:-1]:
self.feed(arg)
self.feed("%s," % args[-1])
self.feed('%s,' % args[-1])
else:
self.feed("%s: %s," % (key_string, arg_string))
self.feed('%s: %s,' % (key_string, arg_string))
imports.update(key_imports)
imports.update(arg_imports)
self.unindent()
self.feed("},")
self.feed('},')
else:
self.feed("%s=[" % _arg_name)
self.feed('%s=[' % _arg_name)
self.indent()
for item in _arg_value:
arg_string, arg_imports = MigrationWriter.serialize(item)
@@ -53,22 +53,22 @@ class OperationWriter:
if len(args) > 1:
for arg in args[:-1]:
self.feed(arg)
self.feed("%s," % args[-1])
self.feed('%s,' % args[-1])
else:
self.feed("%s," % arg_string)
self.feed('%s,' % arg_string)
imports.update(arg_imports)
self.unindent()
self.feed("],")
self.feed('],')
else:
arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
args = arg_string.splitlines()
if len(args) > 1:
self.feed("%s=%s" % (_arg_name, args[0]))
self.feed('%s=%s' % (_arg_name, args[0]))
for arg in args[1:-1]:
self.feed(arg)
self.feed("%s," % args[-1])
self.feed('%s,' % args[-1])
else:
self.feed("%s=%s," % (_arg_name, arg_string))
self.feed('%s=%s,' % (_arg_name, arg_string))
imports.update(arg_imports)
imports = set()
@@ -79,10 +79,10 @@ class OperationWriter:
# We can just use the fact we already have that imported,
# otherwise, we need to add an import for the operation class.
if getattr(migrations, name, None) == self.operation.__class__:
self.feed("migrations.%s(" % name)
self.feed('migrations.%s(' % name)
else:
imports.add("import %s" % (self.operation.__class__.__module__))
self.feed("%s.%s(" % (self.operation.__class__.__module__, name))
imports.add('import %s' % (self.operation.__class__.__module__))
self.feed('%s.%s(' % (self.operation.__class__.__module__, name))
self.indent()
@@ -99,7 +99,7 @@ class OperationWriter:
_write(arg_name, arg_value)
self.unindent()
self.feed("),")
self.feed('),')
return self.render(), imports
def indent(self):
@@ -109,10 +109,10 @@ class OperationWriter:
self.indentation -= 1
def feed(self, line):
self.buff.append(" " * (self.indentation * 4) + line)
self.buff.append(' ' * (self.indentation * 4) + line)
def render(self):
return "\n".join(self.buff)
return '\n'.join(self.buff)
class MigrationWriter:
@@ -147,10 +147,7 @@ class MigrationWriter:
dependencies = []
for dependency in self.migration.dependencies:
if dependency[0] == "__setting__":
dependencies.append(
" migrations.swappable_dependency(settings.%s),"
% dependency[1]
)
dependencies.append(" migrations.swappable_dependency(settings.%s)," % dependency[1])
imports.add("from django.conf import settings")
else:
dependencies.append(" %s," % self.serialize(dependency)[0])
@@ -186,28 +183,24 @@ class MigrationWriter:
) % "\n# ".join(sorted(migration_imports))
# If there's a replaces, make a string for it
if self.migration.replaces:
items["replaces_str"] = (
"\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
)
items['replaces_str'] = "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
# Hinting that goes into comment
if self.include_header:
items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
"version": get_version(),
"timestamp": now().strftime("%Y-%m-%d %H:%M"),
items['migration_header'] = MIGRATION_HEADER_TEMPLATE % {
'version': get_version(),
'timestamp': now().strftime("%Y-%m-%d %H:%M"),
}
else:
items["migration_header"] = ""
items['migration_header'] = ""
if self.migration.initial:
items["initial_str"] = "\n initial = True\n"
items['initial_str'] = "\n initial = True\n"
return MIGRATION_TEMPLATE % items
@property
def basedir(self):
migrations_package_name, _ = MigrationLoader.migrations_module(
self.migration.app_label
)
migrations_package_name, _ = MigrationLoader.migrations_module(self.migration.app_label)
if migrations_package_name is None:
raise ValueError(
@@ -229,11 +222,7 @@ class MigrationWriter:
# Alright, see if it's a direct submodule of the app
app_config = apps.get_app_config(self.migration.app_label)
(
maybe_app_name,
_,
migrations_package_basename,
) = migrations_package_name.rpartition(".")
maybe_app_name, _, migrations_package_basename = migrations_package_name.rpartition(".")
if app_config.name == maybe_app_name:
return os.path.join(app_config.path, migrations_package_basename)
@@ -257,8 +246,8 @@ class MigrationWriter:
raise ValueError(
"Could not locate an appropriate location to create "
"migrations package %s. Make sure the toplevel "
"package exists and can be imported." % migrations_package_name
)
"package exists and can be imported." %
migrations_package_name)
final_dir = os.path.join(base_dir, *missing_dirs)
os.makedirs(final_dir, exist_ok=True)