测试gitnore

This commit is contained in:
ladeng07
2022-05-06 15:45:57 +08:00
parent 12f390949b
commit 51552904f9
2347 changed files with 120102 additions and 53549 deletions
+12 -29
View File
@@ -1,38 +1,21 @@
"""Django Unit Test framework."""
from django.test.client import AsyncClient, AsyncRequestFactory, Client, RequestFactory
from django.test.client import (
AsyncClient, AsyncRequestFactory, Client, RequestFactory,
)
from django.test.testcases import (
LiveServerTestCase,
SimpleTestCase,
TestCase,
TransactionTestCase,
skipIfDBFeature,
skipUnlessAnyDBFeature,
skipUnlessDBFeature,
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
)
from django.test.utils import (
ignore_warnings,
modify_settings,
override_settings,
override_system_checks,
tag,
ignore_warnings, modify_settings, override_settings,
override_system_checks, tag,
)
__all__ = [
"AsyncClient",
"AsyncRequestFactory",
"Client",
"RequestFactory",
"TestCase",
"TransactionTestCase",
"SimpleTestCase",
"LiveServerTestCase",
"skipIfDBFeature",
"skipUnlessAnyDBFeature",
"skipUnlessDBFeature",
"ignore_warnings",
"modify_settings",
"override_settings",
"override_system_checks",
"tag",
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
'ignore_warnings', 'modify_settings', 'override_settings',
'override_system_checks', 'tag',
]
+213 -377
View File
@@ -16,7 +16,9 @@ from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler
from django.core.handlers.wsgi import WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
from django.core.signals import got_request_exception, request_finished, request_started
from django.core.signals import (
got_request_exception, request_finished, request_started,
)
from django.db import close_old_connections
from django.http import HttpRequest, QueryDict, SimpleCookie
from django.test import signals
@@ -29,26 +31,20 @@ from django.utils.itercompat import is_iterable
from django.utils.regex_helper import _lazy_re_compile
__all__ = (
"AsyncClient",
"AsyncRequestFactory",
"Client",
"RedirectCycleError",
"RequestFactory",
"encode_file",
"encode_multipart",
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RedirectCycleError',
'RequestFactory', 'encode_file', 'encode_multipart',
)
BOUNDARY = "BoUnDaRyStRiNg"
MULTIPART_CONTENT = "multipart/form-data; boundary=%s" % BOUNDARY
CONTENT_TYPE_RE = _lazy_re_compile(r".*; charset=([\w\d-]+);?")
BOUNDARY = 'BoUnDaRyStRiNg'
MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
CONTENT_TYPE_RE = _lazy_re_compile(r'.*; charset=([\w\d-]+);?')
# Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8
JSON_CONTENT_TYPE_RE = _lazy_re_compile(r"^application\/(.+\+)?json")
JSON_CONTENT_TYPE_RE = _lazy_re_compile(r'^application\/(.+\+)?json')
class RedirectCycleError(Exception):
"""The test client has been asked to follow a redirect loop."""
def __init__(self, message, last_response):
super().__init__(message)
self.last_response = last_response
@@ -62,7 +58,6 @@ class FakePayload:
length. This makes sure that views can't do anything under the test client
that wouldn't work in real life.
"""
def __init__(self, content=None):
self.__content = BytesIO()
self.__len = 0
@@ -79,9 +74,7 @@ class FakePayload:
self.read_started = True
if num_bytes is None:
num_bytes = self.__len or 0
assert (
self.__len >= num_bytes
), "Cannot read more than the available bytes from the HTTP incoming data."
assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
content = self.__content.read(num_bytes)
self.__len -= num_bytes
return content
@@ -99,13 +92,13 @@ def closing_iterator_wrapper(iterable, close):
yield from iterable
finally:
request_finished.disconnect(close_old_connections)
close() # will fire request_finished
close() # will fire request_finished
request_finished.connect(close_old_connections)
def conditional_content_removal(request, response):
"""
Simulate the behavior of most web servers by removing the content of
Simulate the behavior of most Web servers by removing the content of
responses for HEAD requests, 1xx, 204, and 304 responses. Ensure
compliance with RFC 7230, section 3.3.3.
"""
@@ -113,22 +106,21 @@ def conditional_content_removal(request, response):
if response.streaming:
response.streaming_content = []
else:
response.content = b""
if request.method == "HEAD":
response.content = b''
if request.method == 'HEAD':
if response.streaming:
response.streaming_content = []
else:
response.content = b""
response.content = b''
return response
class ClientHandler(BaseHandler):
"""
An HTTP Handler that can be used for testing purposes. Use the WSGI
A HTTP Handler that can be used for testing purposes. Use the WSGI
interface to compose requests, but return the raw HttpResponse object with
the originating WSGIRequest attached to its ``wsgi_request`` attribute.
"""
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
self.enforce_csrf_checks = enforce_csrf_checks
super().__init__(*args, **kwargs)
@@ -152,7 +144,7 @@ class ClientHandler(BaseHandler):
# Request goes through middleware.
response = self.get_response(request)
# Simulate behaviors of most web servers.
# Simulate behaviors of most Web servers.
conditional_content_removal(request, response)
# Attach the originating request to the response so that it could be
@@ -162,11 +154,10 @@ class ClientHandler(BaseHandler):
# Emulate a WSGI server by calling the close method on completion.
if response.streaming:
response.streaming_content = closing_iterator_wrapper(
response.streaming_content, response.close
)
response.streaming_content, response.close)
else:
request_finished.disconnect(close_old_connections)
response.close() # will fire request_finished
response.close() # will fire request_finished
request_finished.connect(close_old_connections)
return response
@@ -174,7 +165,6 @@ class ClientHandler(BaseHandler):
class AsyncClientHandler(BaseHandler):
"""An async version of ClientHandler."""
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
self.enforce_csrf_checks = enforce_csrf_checks
super().__init__(*args, **kwargs)
@@ -185,15 +175,13 @@ class AsyncClientHandler(BaseHandler):
if self._middleware_chain is None:
self.load_middleware(is_async=True)
# Extract body file from the scope, if provided.
if "_body_file" in scope:
body_file = scope.pop("_body_file")
if '_body_file' in scope:
body_file = scope.pop('_body_file')
else:
body_file = FakePayload("")
body_file = FakePayload('')
request_started.disconnect(close_old_connections)
await sync_to_async(request_started.send, thread_sensitive=False)(
sender=self.__class__, scope=scope
)
await sync_to_async(request_started.send, thread_sensitive=False)(sender=self.__class__, scope=scope)
request_started.connect(close_old_connections)
request = ASGIRequest(scope, body_file)
# Sneaky little hack so that we can easily get round
@@ -202,16 +190,14 @@ class AsyncClientHandler(BaseHandler):
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
# Request goes through middleware.
response = await self.get_response_async(request)
# Simulate behaviors of most web servers.
# Simulate behaviors of most Web servers.
conditional_content_removal(request, response)
# Attach the originating ASGI request to the response so that it could
# be later retrieved.
response.asgi_request = request
# Emulate a server by calling the close method on completion.
if response.streaming:
response.streaming_content = await sync_to_async(
closing_iterator_wrapper, thread_sensitive=False
)(
response.streaming_content = await sync_to_async(closing_iterator_wrapper, thread_sensitive=False)(
response.streaming_content,
response.close,
)
@@ -230,10 +216,10 @@ def store_rendered_templates(store, signal, sender, template, context, **kwargs)
The context is copied so that it is an accurate representation at the time
of rendering.
"""
store.setdefault("templates", []).append(template)
if "context" not in store:
store["context"] = ContextList()
store["context"].append(copy(context))
store.setdefault('templates', []).append(template)
if 'context' not in store:
store['context'] = ContextList()
store['context'].append(copy(context))
def encode_multipart(boundary, data):
@@ -269,33 +255,25 @@ def encode_multipart(boundary, data):
if is_file(item):
lines.extend(encode_file(boundary, key, item))
else:
lines.extend(
to_bytes(val)
for val in [
"--%s" % boundary,
'Content-Disposition: form-data; name="%s"' % key,
"",
item,
]
)
lines.extend(to_bytes(val) for val in [
'--%s' % boundary,
'Content-Disposition: form-data; name="%s"' % key,
'',
item
])
else:
lines.extend(
to_bytes(val)
for val in [
"--%s" % boundary,
'Content-Disposition: form-data; name="%s"' % key,
"",
value,
]
)
lines.extend(to_bytes(val) for val in [
'--%s' % boundary,
'Content-Disposition: form-data; name="%s"' % key,
'',
value
])
lines.extend(
[
to_bytes("--%s--" % boundary),
b"",
]
)
return b"\r\n".join(lines)
lines.extend([
to_bytes('--%s--' % boundary),
b'',
])
return b'\r\n'.join(lines)
def encode_file(boundary, key, file):
@@ -304,10 +282,10 @@ def encode_file(boundary, key, file):
# file.name might not be a string. For example, it's an int for
# tempfile.TemporaryFile().
file_has_string_name = hasattr(file, "name") and isinstance(file.name, str)
filename = os.path.basename(file.name) if file_has_string_name else ""
file_has_string_name = hasattr(file, 'name') and isinstance(file.name, str)
filename = os.path.basename(file.name) if file_has_string_name else ''
if hasattr(file, "content_type"):
if hasattr(file, 'content_type'):
content_type = file.content_type
elif filename:
content_type = mimetypes.guess_type(filename)[0]
@@ -315,16 +293,15 @@ def encode_file(boundary, key, file):
content_type = None
if content_type is None:
content_type = "application/octet-stream"
content_type = 'application/octet-stream'
filename = filename or key
return [
to_bytes("--%s" % boundary),
to_bytes(
'Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename)
),
to_bytes("Content-Type: %s" % content_type),
b"",
to_bytes(file.read()),
to_bytes('--%s' % boundary),
to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"'
% (key, filename)),
to_bytes('Content-Type: %s' % content_type),
b'',
to_bytes(file.read())
]
@@ -341,7 +318,6 @@ class RequestFactory:
Once you have a request object you can pass it to any view function,
just as if that view had been hooked up using a URLconf.
"""
def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
self.json_encoder = json_encoder
self.defaults = defaults
@@ -357,26 +333,24 @@ class RequestFactory:
# - REMOTE_ADDR: often useful, see #8551.
# See https://www.python.org/dev/peps/pep-3333/#environ-variables
return {
"HTTP_COOKIE": "; ".join(
sorted(
"%s=%s" % (morsel.key, morsel.coded_value)
for morsel in self.cookies.values()
)
),
"PATH_INFO": "/",
"REMOTE_ADDR": "127.0.0.1",
"REQUEST_METHOD": "GET",
"SCRIPT_NAME": "",
"SERVER_NAME": "testserver",
"SERVER_PORT": "80",
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.version": (1, 0),
"wsgi.url_scheme": "http",
"wsgi.input": FakePayload(b""),
"wsgi.errors": self.errors,
"wsgi.multiprocess": True,
"wsgi.multithread": False,
"wsgi.run_once": False,
'HTTP_COOKIE': '; '.join(sorted(
'%s=%s' % (morsel.key, morsel.coded_value)
for morsel in self.cookies.values()
)),
'PATH_INFO': '/',
'REMOTE_ADDR': '127.0.0.1',
'REQUEST_METHOD': 'GET',
'SCRIPT_NAME': '',
'SERVER_NAME': 'testserver',
'SERVER_PORT': '80',
'SERVER_PROTOCOL': 'HTTP/1.1',
'wsgi.version': (1, 0),
'wsgi.url_scheme': 'http',
'wsgi.input': FakePayload(b''),
'wsgi.errors': self.errors,
'wsgi.multiprocess': True,
'wsgi.multithread': False,
'wsgi.run_once': False,
**self.defaults,
**request,
}
@@ -402,9 +376,7 @@ class RequestFactory:
Return encoded JSON if data is a dict, list, or tuple and content_type
is application/json.
"""
should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(
data, (dict, list, tuple)
)
should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple))
return json.dumps(data, cls=self.json_encoder) if should_encode else data
def _get_path(self, parsed):
@@ -416,128 +388,88 @@ class RequestFactory:
# Replace the behavior where non-ASCII values in the WSGI environ are
# arbitrarily decoded with ISO-8859-1.
# Refs comment in `get_bytes_from_wsgi()`.
return path.decode("iso-8859-1")
return path.decode('iso-8859-1')
def get(self, path, data=None, secure=False, **extra):
"""Construct a GET request."""
data = {} if data is None else data
return self.generic(
"GET",
path,
secure=secure,
**{
"QUERY_STRING": urlencode(data, doseq=True),
**extra,
},
)
return self.generic('GET', path, secure=secure, **{
'QUERY_STRING': urlencode(data, doseq=True),
**extra,
})
def post(
self, path, data=None, content_type=MULTIPART_CONTENT, secure=False, **extra
):
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
secure=False, **extra):
"""Construct a POST request."""
data = self._encode_json({} if data is None else data, content_type)
post_data = self._encode_data(data, content_type)
return self.generic(
"POST", path, post_data, content_type, secure=secure, **extra
)
return self.generic('POST', path, post_data, content_type,
secure=secure, **extra)
def head(self, path, data=None, secure=False, **extra):
"""Construct a HEAD request."""
data = {} if data is None else data
return self.generic(
"HEAD",
path,
secure=secure,
**{
"QUERY_STRING": urlencode(data, doseq=True),
**extra,
},
)
return self.generic('HEAD', path, secure=secure, **{
'QUERY_STRING': urlencode(data, doseq=True),
**extra,
})
def trace(self, path, secure=False, **extra):
"""Construct a TRACE request."""
return self.generic("TRACE", path, secure=secure, **extra)
return self.generic('TRACE', path, secure=secure, **extra)
def options(
self,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra,
):
def options(self, path, data='', content_type='application/octet-stream',
secure=False, **extra):
"Construct an OPTIONS request."
return self.generic("OPTIONS", path, data, content_type, secure=secure, **extra)
return self.generic('OPTIONS', path, data, content_type,
secure=secure, **extra)
def put(
self,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra,
):
def put(self, path, data='', content_type='application/octet-stream',
secure=False, **extra):
"""Construct a PUT request."""
data = self._encode_json(data, content_type)
return self.generic("PUT", path, data, content_type, secure=secure, **extra)
return self.generic('PUT', path, data, content_type,
secure=secure, **extra)
def patch(
self,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra,
):
def patch(self, path, data='', content_type='application/octet-stream',
secure=False, **extra):
"""Construct a PATCH request."""
data = self._encode_json(data, content_type)
return self.generic("PATCH", path, data, content_type, secure=secure, **extra)
return self.generic('PATCH', path, data, content_type,
secure=secure, **extra)
def delete(
self,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra,
):
def delete(self, path, data='', content_type='application/octet-stream',
secure=False, **extra):
"""Construct a DELETE request."""
data = self._encode_json(data, content_type)
return self.generic("DELETE", path, data, content_type, secure=secure, **extra)
return self.generic('DELETE', path, data, content_type,
secure=secure, **extra)
def generic(
self,
method,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra,
):
def generic(self, method, path, data='',
content_type='application/octet-stream', secure=False,
**extra):
"""Construct an arbitrary HTTP request."""
parsed = urlparse(str(path)) # path can be lazy
data = force_bytes(data, settings.DEFAULT_CHARSET)
r = {
"PATH_INFO": self._get_path(parsed),
"REQUEST_METHOD": method,
"SERVER_PORT": "443" if secure else "80",
"wsgi.url_scheme": "https" if secure else "http",
'PATH_INFO': self._get_path(parsed),
'REQUEST_METHOD': method,
'SERVER_PORT': '443' if secure else '80',
'wsgi.url_scheme': 'https' if secure else 'http',
}
if data:
r.update(
{
"CONTENT_LENGTH": str(len(data)),
"CONTENT_TYPE": content_type,
"wsgi.input": FakePayload(data),
}
)
r.update({
'CONTENT_LENGTH': str(len(data)),
'CONTENT_TYPE': content_type,
'wsgi.input': FakePayload(data),
})
r.update(extra)
# If QUERY_STRING is absent or empty, we want to extract it from the URL.
if not r.get("QUERY_STRING"):
if not r.get('QUERY_STRING'):
# WSGI requires latin-1 encoded strings. See get_path_info().
query_string = parsed[4].encode().decode("iso-8859-1")
r["QUERY_STRING"] = query_string
query_string = parsed[4].encode().decode('iso-8859-1')
r['QUERY_STRING'] = query_string
return self.request(**r)
@@ -555,35 +487,30 @@ class AsyncRequestFactory(RequestFactory):
a) this makes ASGIRequest subclasses, and
b) AsyncTestClient can subclass it.
"""
def _base_scope(self, **request):
"""The base scope for a request."""
# This is a minimal valid ASGI scope, plus:
# - headers['cookie'] for cookie support,
# - 'client' often useful, see #8551.
scope = {
"asgi": {"version": "3.0"},
"type": "http",
"http_version": "1.1",
"client": ["127.0.0.1", 0],
"server": ("testserver", "80"),
"scheme": "http",
"method": "GET",
"headers": [],
'asgi': {'version': '3.0'},
'type': 'http',
'http_version': '1.1',
'client': ['127.0.0.1', 0],
'server': ('testserver', '80'),
'scheme': 'http',
'method': 'GET',
'headers': [],
**self.defaults,
**request,
}
scope["headers"].append(
(
b"cookie",
b"; ".join(
sorted(
("%s=%s" % (morsel.key, morsel.coded_value)).encode("ascii")
for morsel in self.cookies.values()
)
),
)
)
scope['headers'].append((
b'cookie',
b'; '.join(sorted(
('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
for morsel in self.cookies.values()
)),
))
return scope
def request(self, **request):
@@ -591,52 +518,43 @@ class AsyncRequestFactory(RequestFactory):
# This is synchronous, which means all methods on this class are.
# AsyncClient, however, has an async request function, which makes all
# its methods async.
if "_body_file" in request:
body_file = request.pop("_body_file")
if '_body_file' in request:
body_file = request.pop('_body_file')
else:
body_file = FakePayload("")
body_file = FakePayload('')
return ASGIRequest(self._base_scope(**request), body_file)
def generic(
self,
method,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra,
self, method, path, data='', content_type='application/octet-stream',
secure=False, **extra,
):
"""Construct an arbitrary HTTP request."""
parsed = urlparse(str(path)) # path can be lazy.
data = force_bytes(data, settings.DEFAULT_CHARSET)
s = {
"method": method,
"path": self._get_path(parsed),
"server": ("127.0.0.1", "443" if secure else "80"),
"scheme": "https" if secure else "http",
"headers": [(b"host", b"testserver")],
'method': method,
'path': self._get_path(parsed),
'server': ('127.0.0.1', '443' if secure else '80'),
'scheme': 'https' if secure else 'http',
'headers': [(b'host', b'testserver')],
}
if data:
s["headers"].extend(
[
(b"content-length", str(len(data)).encode("ascii")),
(b"content-type", content_type.encode("ascii")),
]
)
s["_body_file"] = FakePayload(data)
follow = extra.pop("follow", None)
s['headers'].extend([
(b'content-length', str(len(data)).encode('ascii')),
(b'content-type', content_type.encode('ascii')),
])
s['_body_file'] = FakePayload(data)
follow = extra.pop('follow', None)
if follow is not None:
s["follow"] = follow
if query_string := extra.pop("QUERY_STRING", None):
s["query_string"] = query_string
s["headers"] += [
(key.lower().encode("ascii"), value.encode("latin1"))
s['follow'] = follow
s['headers'] += [
(key.lower().encode('ascii'), value.encode('latin1'))
for key, value in extra.items()
]
# If QUERY_STRING is absent or empty, we want to extract it from the
# URL.
if not s.get("query_string"):
s["query_string"] = parsed[4]
if not s.get('query_string'):
s['query_string'] = parsed[4]
return self.request(**s)
@@ -644,7 +562,6 @@ class ClientMixin:
"""
Mixin with common methods between Client and AsyncClient.
"""
def store_exc_info(self, **kwargs):
"""Store exceptions when they are generated by a view."""
self.exc_info = sys.exc_info()
@@ -682,7 +599,6 @@ class ClientMixin:
are incorrect.
"""
from django.contrib.auth import authenticate
user = authenticate(**credentials)
if user:
self._login(user)
@@ -692,10 +608,9 @@ class ClientMixin:
def force_login(self, user, backend=None):
def get_backend():
from django.contrib.auth import load_backend
for backend_path in settings.AUTHENTICATION_BACKENDS:
backend = load_backend(backend_path)
if hasattr(backend, "get_user"):
if hasattr(backend, 'get_user'):
return backend_path
if backend is None:
@@ -720,18 +635,17 @@ class ClientMixin:
session_cookie = settings.SESSION_COOKIE_NAME
self.cookies[session_cookie] = request.session.session_key
cookie_data = {
"max-age": None,
"path": "/",
"domain": settings.SESSION_COOKIE_DOMAIN,
"secure": settings.SESSION_COOKIE_SECURE or None,
"expires": None,
'max-age': None,
'path': '/',
'domain': settings.SESSION_COOKIE_DOMAIN,
'secure': settings.SESSION_COOKIE_SECURE or None,
'expires': None,
}
self.cookies[session_cookie].update(cookie_data)
def logout(self):
"""Log out the user by removing the cookies and session object."""
from django.contrib.auth import get_user, logout
request = HttpRequest()
if self.session:
request.session = self.session
@@ -743,15 +657,13 @@ class ClientMixin:
self.cookies = SimpleCookie()
def _parse_json(self, response, **extra):
if not hasattr(response, "_json"):
if not JSON_CONTENT_TYPE_RE.match(response.get("Content-Type")):
if not hasattr(response, '_json'):
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
raise ValueError(
'Content-Type header is "%s", not "application/json"'
% response.get("Content-Type")
% response.get('Content-Type')
)
response._json = json.loads(
response.content.decode(response.charset), **extra
)
response._json = json.loads(response.content.decode(response.charset), **extra)
return response._json
@@ -773,10 +685,7 @@ class Client(ClientMixin, RequestFactory):
contexts and templates produced by a view, rather than the
HTML rendered to the end-user.
"""
def __init__(
self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
):
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
super().__init__(**defaults)
self.handler = ClientHandler(enforce_csrf_checks)
self.raise_request_exception = raise_request_exception
@@ -812,14 +721,11 @@ class Client(ClientMixin, RequestFactory):
response.client = self
response.request = request
# Add any rendered template detail to the response.
response.templates = data.get("templates", [])
response.context = data.get("context")
response.templates = data.get('templates', [])
response.context = data.get('context')
response.json = partial(self._parse_json, response)
# Attach the ResolverMatch instance to the response.
urlconf = getattr(response.wsgi_request, "urlconf", None)
response.resolver_match = SimpleLazyObject(
lambda: resolve(request["PATH_INFO"], urlconf=urlconf),
)
response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO']))
# Flatten a single context. Not really necessary anymore thanks to the
# __getattr__ flattening in ContextList, but has some edge case
# backwards compatibility implications.
@@ -838,24 +744,13 @@ class Client(ClientMixin, RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
def post(
self,
path,
data=None,
content_type=MULTIPART_CONTENT,
follow=False,
secure=False,
**extra,
):
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
follow=False, secure=False, **extra):
"""Request a response from the server using POST."""
self.extra = extra
response = super().post(
path, data=data, content_type=content_type, secure=secure, **extra
)
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(
response, data=data, content_type=content_type, **extra
)
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def head(self, path, data=None, follow=False, secure=False, **extra):
@@ -866,87 +761,43 @@ class Client(ClientMixin, RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
def options(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
**extra,
):
def options(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Request a response from the server using OPTIONS."""
self.extra = extra
response = super().options(
path, data=data, content_type=content_type, secure=secure, **extra
)
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(
response, data=data, content_type=content_type, **extra
)
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def put(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
**extra,
):
def put(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Send a resource to the server using PUT."""
self.extra = extra
response = super().put(
path, data=data, content_type=content_type, secure=secure, **extra
)
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(
response, data=data, content_type=content_type, **extra
)
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def patch(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
**extra,
):
def patch(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Send a resource to the server using PATCH."""
self.extra = extra
response = super().patch(
path, data=data, content_type=content_type, secure=secure, **extra
)
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(
response, data=data, content_type=content_type, **extra
)
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def delete(
self,
path,
data="",
content_type="application/octet-stream",
follow=False,
secure=False,
**extra,
):
def delete(self, path, data='', content_type='application/octet-stream',
follow=False, secure=False, **extra):
"""Send a DELETE request to the server."""
self.extra = extra
response = super().delete(
path, data=data, content_type=content_type, secure=secure, **extra
)
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
if follow:
response = self._handle_redirects(
response, data=data, content_type=content_type, **extra
)
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
return response
def trace(self, path, data="", follow=False, secure=False, **extra):
def trace(self, path, data='', follow=False, secure=False, **extra):
"""Send a TRACE request to the server."""
self.extra = extra
response = super().trace(path, data=data, secure=secure, **extra)
@@ -954,7 +805,7 @@ class Client(ClientMixin, RequestFactory):
response = self._handle_redirects(response, data=data, **extra)
return response
def _handle_redirects(self, response, data="", content_type="", **extra):
def _handle_redirects(self, response, data='', content_type='', **extra):
"""
Follow any redirects by requesting responses from the server using GET.
"""
@@ -973,46 +824,36 @@ class Client(ClientMixin, RequestFactory):
url = urlsplit(response_url)
if url.scheme:
extra["wsgi.url_scheme"] = url.scheme
extra['wsgi.url_scheme'] = url.scheme
if url.hostname:
extra["SERVER_NAME"] = url.hostname
extra['SERVER_NAME'] = url.hostname
if url.port:
extra["SERVER_PORT"] = str(url.port)
extra['SERVER_PORT'] = str(url.port)
path = url.path
# RFC 2616: bare domains without path are treated as the root.
if not path and url.netloc:
path = "/"
# Prepend the request path to handle relative path redirects
if not path.startswith("/"):
path = urljoin(response.request["PATH_INFO"], path)
path = url.path
if not path.startswith('/'):
path = urljoin(response.request['PATH_INFO'], path)
if response.status_code in (
HTTPStatus.TEMPORARY_REDIRECT,
HTTPStatus.PERMANENT_REDIRECT,
):
if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT):
# Preserve request method and query string (if needed)
# post-redirect for 307/308 responses.
request_method = response.request["REQUEST_METHOD"].lower()
if request_method not in ("get", "head"):
extra["QUERY_STRING"] = url.query
request_method = response.request['REQUEST_METHOD'].lower()
if request_method not in ('get', 'head'):
extra['QUERY_STRING'] = url.query
request_method = getattr(self, request_method)
else:
request_method = self.get
data = QueryDict(url.query)
content_type = None
response = request_method(
path, data=data, content_type=content_type, follow=False, **extra
)
response = request_method(path, data=data, content_type=content_type, follow=False, **extra)
response.redirect_chain = redirect_chain
if redirect_chain[-1] in redirect_chain[:-1]:
# Check that we're not redirecting to somewhere we've already
# been to, to prevent loops.
raise RedirectCycleError(
"Redirect loop detected.", last_response=response
)
raise RedirectCycleError("Redirect loop detected.", last_response=response)
if len(redirect_chain) > 20:
# Such a lengthy chain likely also means a loop, but one with
# a growing path, changing view, or changing query argument;
@@ -1029,10 +870,7 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
Does not currently support "follow" on its methods.
"""
def __init__(
self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
):
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
super().__init__(**defaults)
self.handler = AsyncClientHandler(enforce_csrf_checks)
self.raise_request_exception = raise_request_exception
@@ -1046,19 +884,20 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
query environment, which can be overridden using the arguments to the
request.
"""
if "follow" in request:
if 'follow' in request:
raise NotImplementedError(
"AsyncClient request methods do not accept the follow parameter."
'AsyncClient request methods do not accept the follow '
'parameter.'
)
scope = self._base_scope(**request)
# Curry a data dictionary into an instance of the template renderer
# callback function.
data = {}
on_template_render = partial(store_rendered_templates, data)
signal_uid = "template-render-%s" % id(request)
signal_uid = 'template-render-%s' % id(request)
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
# Capture exceptions created by the handler.
exception_uid = "request-exception-%s" % id(request)
exception_uid = 'request-exception-%s' % id(request)
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
try:
response = await self.handler(scope)
@@ -1071,14 +910,11 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
response.client = self
response.request = request
# Add any rendered template detail to the response.
response.templates = data.get("templates", [])
response.context = data.get("context")
response.templates = data.get('templates', [])
response.context = data.get('context')
response.json = partial(self._parse_json, response)
# Attach the ResolverMatch instance to the response.
urlconf = getattr(response.asgi_request, "urlconf", None)
response.resolver_match = SimpleLazyObject(
lambda: resolve(request["path"], urlconf=urlconf),
)
response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
# Flatten a single context. Not really necessary anymore thanks to the
# __getattr__ flattening in ContextList, but has some edge case
# backwards compatibility implications.
+43 -87
View File
@@ -7,62 +7,11 @@ from django.utils.regex_helper import _lazy_re_compile
# ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020
# SPACE.
# https://infra.spec.whatwg.org/#ascii-whitespace
ASCII_WHITESPACE = _lazy_re_compile(r"[\t\n\f\r ]+")
# https://html.spec.whatwg.org/#attributes-3
BOOLEAN_ATTRIBUTES = {
"allowfullscreen",
"async",
"autofocus",
"autoplay",
"checked",
"controls",
"default",
"defer ",
"disabled",
"formnovalidate",
"hidden",
"ismap",
"itemscope",
"loop",
"multiple",
"muted",
"nomodule",
"novalidate",
"open",
"playsinline",
"readonly",
"required",
"reversed",
"selected",
# Attributes for deprecated tags.
"truespeed",
}
ASCII_WHITESPACE = _lazy_re_compile(r'[\t\n\f\r ]+')
def normalize_whitespace(string):
return ASCII_WHITESPACE.sub(" ", string)
def normalize_attributes(attributes):
normalized = []
for name, value in attributes:
if name == "class" and value:
# Special case handling of 'class' attribute, so that comparisons
# of DOM instances are not sensitive to ordering of classes.
value = " ".join(
sorted(value for value in ASCII_WHITESPACE.split(value) if value)
)
# Boolean attributes without a value is same as attribute with value
# that equals the attributes name. For example:
# <input checked> == <input checked="checked">
if name in BOOLEAN_ATTRIBUTES:
if not value or value == name:
value = None
elif value is None:
value = ""
normalized.append((name, value))
return normalized
return ASCII_WHITESPACE.sub(' ', string)
class Element:
@@ -100,14 +49,27 @@ class Element:
for i, child in enumerate(self.children):
if isinstance(child, str):
self.children[i] = child.strip()
elif hasattr(child, "finalize"):
elif hasattr(child, 'finalize'):
child.finalize()
def __eq__(self, element):
if not hasattr(element, "name") or self.name != element.name:
if not hasattr(element, 'name') or self.name != element.name:
return False
if len(self.attributes) != len(element.attributes):
return False
if self.attributes != element.attributes:
return False
# attributes without a value is same as attribute with value that
# equals the attributes name:
# <input checked> == <input checked="checked">
for i in range(len(self.attributes)):
attr, value = self.attributes[i]
other_attr, other_value = element.attributes[i]
if value is None:
value = attr
if other_value is None:
other_value = other_attr
if attr != other_attr or value != other_value:
return False
return self.children == element.children
def __hash__(self):
@@ -162,18 +124,18 @@ class Element:
return self.children[key]
def __str__(self):
output = "<%s" % self.name
output = '<%s' % self.name
for key, value in self.attributes:
if value is not None:
if value:
output += ' %s="%s"' % (key, value)
else:
output += " %s" % key
output += ' %s' % key
if self.children:
output += ">\n"
output += "".join(str(c) for c in self.children)
output += "\n</%s>" % self.name
output += '>\n'
output += ''.join(str(c) for c in self.children)
output += '\n</%s>' % self.name
else:
output += ">"
output += '>'
return output
def __repr__(self):
@@ -185,7 +147,7 @@ class RootElement(Element):
super().__init__(None, ())
def __str__(self):
return "".join(str(c) for c in self.children)
return ''.join(str(c) for c in self.children)
class HTMLParseError(Exception):
@@ -195,23 +157,10 @@ class HTMLParseError(Exception):
class Parser(HTMLParser):
# https://html.spec.whatwg.org/#void-elements
SELF_CLOSING_TAGS = {
"area",
"base",
"br",
"col",
"embed",
"hr",
"img",
"input",
"link",
"meta",
"param",
"source",
"track",
"wbr",
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta',
'param', 'source', 'track', 'wbr',
# Deprecated tags
"frame",
"spacer",
'frame', 'spacer',
}
def __init__(self):
@@ -228,9 +177,9 @@ class Parser(HTMLParser):
position = self.element_positions[element]
if position is None:
position = self.getpos()
if hasattr(position, "lineno"):
if hasattr(position, 'lineno'):
position = position.lineno, position.offset
return "Line %d, Column %d" % position
return 'Line %d, Column %d' % position
@property
def current(self):
@@ -245,7 +194,14 @@ class Parser(HTMLParser):
self.handle_endtag(tag)
def handle_starttag(self, tag, attrs):
attrs = normalize_attributes(attrs)
# Special case handling of 'class' attribute, so that comparisons of DOM
# instances are not sensitive to ordering of classes.
attrs = [
(name, ' '.join(sorted(value for value in ASCII_WHITESPACE.split(value) if value)))
if name == "class"
else (name, value)
for name, value in attrs
]
element = Element(tag, attrs)
self.current.append(element)
if tag not in self.SELF_CLOSING_TAGS:
@@ -254,13 +210,13 @@ class Parser(HTMLParser):
def handle_endtag(self, tag):
if not self.open_tags:
self.error("Unexpected end tag `%s` (%s)" % (tag, self.format_position()))
self.error("Unexpected end tag `%s` (%s)" % (
tag, self.format_position()))
element = self.open_tags.pop()
while element.name != tag:
if not self.open_tags:
self.error(
"Unexpected end tag `%s` (%s)" % (tag, self.format_position())
)
self.error("Unexpected end tag `%s` (%s)" % (
tag, self.format_position()))
element = self.open_tags.pop()
def handle_data(self, data):
File diff suppressed because it is too large Load Diff
+15 -17
View File
@@ -27,9 +27,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
"""
test_class = super().__new__(cls, name, bases, attrs)
# If the test class is either browser-specific or a test base, return it.
if test_class.browser or not any(
name.startswith("test") and callable(value) for name, value in attrs.items()
):
if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()):
return test_class
elif test_class.browsers:
# Reuse the created test class to make it browser-specific.
@@ -39,7 +37,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
first_browser = test_class.browsers[0]
test_class.browser = first_browser
# Listen on an external interface if using a selenium hub.
host = test_class.host if not test_class.selenium_hub else "0.0.0.0"
host = test_class.host if not test_class.selenium_hub else '0.0.0.0'
test_class.host = host
test_class.external_host = cls.external_host
# Create subclasses for each of the remaining browsers and expose
@@ -51,16 +49,16 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
"%s%s" % (capfirst(browser), name),
(test_class,),
{
"browser": browser,
"host": host,
"external_host": cls.external_host,
"__module__": test_class.__module__,
},
'browser': browser,
'host': host,
'external_host': cls.external_host,
'__module__': test_class.__module__,
}
)
setattr(module, browser_test_class.__name__, browser_test_class)
return test_class
# If no browsers were specified, skip this class (it'll still be discovered).
return unittest.skip("No browsers specified.")(test_class)
return unittest.skip('No browsers specified.')(test_class)
@classmethod
def import_webdriver(cls, browser):
@@ -68,12 +66,13 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
@classmethod
def import_options(cls, browser):
return import_string("selenium.webdriver.%s.options.Options" % browser)
return import_string('selenium.webdriver.%s.options.Options' % browser)
@classmethod
def get_capability(cls, browser):
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
from selenium.webdriver.common.desired_capabilities import (
DesiredCapabilities,
)
return getattr(DesiredCapabilities, browser.upper())
def create_options(self):
@@ -88,7 +87,6 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
def create_webdriver(self):
if self.selenium_hub:
from selenium import webdriver
return webdriver.Remote(
command_executor=self.selenium_hub,
desired_capabilities=self.get_capability(self.browser),
@@ -96,14 +94,14 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
return self.import_webdriver(self.browser)(options=self.create_options())
@tag("selenium")
@tag('selenium')
class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
implicit_wait = 10
external_host = None
@classproperty
def live_server_url(cls):
return "http://%s:%s" % (cls.external_host or cls.host, cls.server_thread.port)
return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port)
@classproperty
def allowed_host(cls):
@@ -120,7 +118,7 @@ class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
# quit() the WebDriver before attempting to terminate and join the
# single-threaded LiveServerThread to avoid a dead lock if the browser
# kept a connection alive.
if hasattr(cls, "selenium"):
if hasattr(cls, 'selenium'):
cls.selenium.quit()
super()._tearDownClassInternal()
+32 -60
View File
@@ -20,14 +20,13 @@ template_rendered = Signal()
# except for cases where the receiver is related to a contrib app.
# Settings that may not work well when using 'override_settings' (#19031)
COMPLEX_OVERRIDE_SETTINGS = {"DATABASES"}
COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'}
@receiver(setting_changed)
def clear_cache_handlers(**kwargs):
if kwargs["setting"] == "CACHES":
if kwargs['setting'] == 'CACHES':
from django.core.cache import caches, close_caches
close_caches()
caches._settings = caches.settings = caches.configure_settings(None)
caches._connections = Local()
@@ -35,41 +34,37 @@ def clear_cache_handlers(**kwargs):
@receiver(setting_changed)
def update_installed_apps(**kwargs):
if kwargs["setting"] == "INSTALLED_APPS":
if kwargs['setting'] == 'INSTALLED_APPS':
# Rebuild any AppDirectoriesFinder instance.
from django.contrib.staticfiles.finders import get_finder
get_finder.cache_clear()
# Rebuild management commands cache
from django.core.management import get_commands
get_commands.cache_clear()
# Rebuild get_app_template_dirs cache.
from django.template.utils import get_app_template_dirs
get_app_template_dirs.cache_clear()
# Rebuild translations cache.
from django.utils.translation import trans_real
trans_real._translations = {}
@receiver(setting_changed)
def update_connections_time_zone(**kwargs):
if kwargs["setting"] == "TIME_ZONE":
if kwargs['setting'] == 'TIME_ZONE':
# Reset process time zone
if hasattr(time, "tzset"):
if kwargs["value"]:
os.environ["TZ"] = kwargs["value"]
if hasattr(time, 'tzset'):
if kwargs['value']:
os.environ['TZ'] = kwargs['value']
else:
os.environ.pop("TZ", None)
os.environ.pop('TZ', None)
time.tzset()
# Reset local time zone cache
timezone.get_default_timezone.cache_clear()
# Reset the database connections' time zone
if kwargs["setting"] in {"TIME_ZONE", "USE_TZ"}:
if kwargs['setting'] in {'TIME_ZONE', 'USE_TZ'}:
for conn in connections.all():
try:
del conn.timezone
@@ -84,19 +79,18 @@ def update_connections_time_zone(**kwargs):
@receiver(setting_changed)
def clear_routers_cache(**kwargs):
if kwargs["setting"] == "DATABASE_ROUTERS":
if kwargs['setting'] == 'DATABASE_ROUTERS':
router.routers = ConnectionRouter().routers
@receiver(setting_changed)
def reset_template_engines(**kwargs):
if kwargs["setting"] in {
"TEMPLATES",
"DEBUG",
"INSTALLED_APPS",
if kwargs['setting'] in {
'TEMPLATES',
'DEBUG',
'INSTALLED_APPS',
}:
from django.template import engines
try:
del engines.templates
except AttributeError:
@@ -104,134 +98,112 @@ def reset_template_engines(**kwargs):
engines._templates = None
engines._engines = {}
from django.template.engine import Engine
Engine.get_default.cache_clear()
from django.forms.renderers import get_default_renderer
get_default_renderer.cache_clear()
@receiver(setting_changed)
def clear_serializers_cache(**kwargs):
if kwargs["setting"] == "SERIALIZATION_MODULES":
if kwargs['setting'] == 'SERIALIZATION_MODULES':
from django.core import serializers
serializers._serializers = {}
@receiver(setting_changed)
def language_changed(**kwargs):
if kwargs["setting"] in {"LANGUAGES", "LANGUAGE_CODE", "LOCALE_PATHS"}:
if kwargs['setting'] in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}:
from django.utils.translation import trans_real
trans_real._default = None
trans_real._active = Local()
if kwargs["setting"] in {"LANGUAGES", "LOCALE_PATHS"}:
if kwargs['setting'] in {'LANGUAGES', 'LOCALE_PATHS'}:
from django.utils.translation import trans_real
trans_real._translations = {}
trans_real.check_for_language.cache_clear()
@receiver(setting_changed)
def localize_settings_changed(**kwargs):
if (
kwargs["setting"] in FORMAT_SETTINGS
or kwargs["setting"] == "USE_THOUSAND_SEPARATOR"
):
if kwargs['setting'] in FORMAT_SETTINGS or kwargs['setting'] == 'USE_THOUSAND_SEPARATOR':
reset_format_cache()
@receiver(setting_changed)
def file_storage_changed(**kwargs):
if kwargs["setting"] == "DEFAULT_FILE_STORAGE":
if kwargs['setting'] == 'DEFAULT_FILE_STORAGE':
from django.core.files.storage import default_storage
default_storage._wrapped = empty
@receiver(setting_changed)
def complex_setting_changed(**kwargs):
if kwargs["enter"] and kwargs["setting"] in COMPLEX_OVERRIDE_SETTINGS:
if kwargs['enter'] and kwargs['setting'] in COMPLEX_OVERRIDE_SETTINGS:
# Considering the current implementation of the signals framework,
# this stacklevel shows the line containing the override_settings call.
warnings.warn(
"Overriding setting %s can lead to unexpected behavior."
% kwargs["setting"],
stacklevel=6,
)
warnings.warn("Overriding setting %s can lead to unexpected behavior."
% kwargs['setting'], stacklevel=6)
@receiver(setting_changed)
def root_urlconf_changed(**kwargs):
if kwargs["setting"] == "ROOT_URLCONF":
if kwargs['setting'] == 'ROOT_URLCONF':
from django.urls import clear_url_caches, set_urlconf
clear_url_caches()
set_urlconf(None)
@receiver(setting_changed)
def static_storage_changed(**kwargs):
if kwargs["setting"] in {
"STATICFILES_STORAGE",
"STATIC_ROOT",
"STATIC_URL",
if kwargs['setting'] in {
'STATICFILES_STORAGE',
'STATIC_ROOT',
'STATIC_URL',
}:
from django.contrib.staticfiles.storage import staticfiles_storage
staticfiles_storage._wrapped = empty
@receiver(setting_changed)
def static_finders_changed(**kwargs):
if kwargs["setting"] in {
"STATICFILES_DIRS",
"STATIC_ROOT",
if kwargs['setting'] in {
'STATICFILES_DIRS',
'STATIC_ROOT',
}:
from django.contrib.staticfiles.finders import get_finder
get_finder.cache_clear()
@receiver(setting_changed)
def auth_password_validators_changed(**kwargs):
if kwargs["setting"] == "AUTH_PASSWORD_VALIDATORS":
if kwargs['setting'] == 'AUTH_PASSWORD_VALIDATORS':
from django.contrib.auth.password_validation import (
get_default_password_validators,
)
get_default_password_validators.cache_clear()
@receiver(setting_changed)
def user_model_swapped(**kwargs):
if kwargs["setting"] == "AUTH_USER_MODEL":
if kwargs['setting'] == 'AUTH_USER_MODEL':
apps.clear_cache()
try:
from django.contrib.auth import get_user_model
UserModel = get_user_model()
except ImproperlyConfigured:
# Some tests set an invalid AUTH_USER_MODEL.
pass
else:
from django.contrib.auth import backends
backends.UserModel = UserModel
from django.contrib.auth import forms
forms.UserModel = UserModel
from django.contrib.auth.handlers import modwsgi
modwsgi.UserModel = UserModel
from django.contrib.auth.management.commands import changepassword
changepassword.UserModel = UserModel
from django.contrib.auth import views
views.UserModel = UserModel
File diff suppressed because it is too large Load Diff
+82 -180
View File
@@ -25,7 +25,6 @@ from django.db.models.options import Options
from django.template import Template
from django.test.signals import setting_changed, template_rendered
from django.urls import get_script_prefix, set_script_prefix
from django.utils.deprecation import RemovedInDjango50Warning
from django.utils.translation import deactivate
try:
@@ -35,24 +34,15 @@ except ImportError:
__all__ = (
"Approximate",
"ContextList",
"isolate_lru_cache",
"get_runner",
"CaptureQueriesContext",
"ignore_warnings",
"isolate_apps",
"modify_settings",
"override_settings",
"override_system_checks",
"tag",
"requires_tz_support",
"setup_databases",
"setup_test_environment",
"teardown_test_environment",
'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner',
'CaptureQueriesContext',
'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings',
'override_system_checks', 'tag',
'requires_tz_support',
'setup_databases', 'setup_test_environment', 'teardown_test_environment',
)
TZ_SUPPORT = hasattr(time, "tzset")
TZ_SUPPORT = hasattr(time, 'tzset')
class Approximate:
@@ -72,7 +62,6 @@ class ContextList(list):
A wrapper that provides direct key access to context items contained
in a list of context objects.
"""
def __getitem__(self, key):
if isinstance(key, str):
for subcontext in self:
@@ -120,7 +109,7 @@ def setup_test_environment(debug=None):
Perform global pre-test setup, such as installing the instrumented template
renderer and setting the email backend to the locmem email backend.
"""
if hasattr(_TestState, "saved_data"):
if hasattr(_TestState, 'saved_data'):
# Executing this function twice would overwrite the saved values.
raise RuntimeError(
"setup_test_environment() was already called and can't be called "
@@ -135,13 +124,13 @@ def setup_test_environment(debug=None):
saved_data.allowed_hosts = settings.ALLOWED_HOSTS
# Add the default host of the test client.
settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"]
settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver']
saved_data.debug = settings.DEBUG
settings.DEBUG = debug
saved_data.email_backend = settings.EMAIL_BACKEND
settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend"
settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
saved_data.template_render = Template._render
Template._render = instrumented_test_render
@@ -167,18 +156,8 @@ def teardown_test_environment():
del mail.outbox
def setup_databases(
verbosity,
interactive,
*,
time_keeper=None,
keepdb=False,
debug_sql=False,
parallel=0,
aliases=None,
serialized_aliases=None,
**kwargs,
):
def setup_databases(verbosity, interactive, *, time_keeper=None, keepdb=False, debug_sql=False, parallel=0,
aliases=None, **kwargs):
"""Create the test databases."""
if time_keeper is None:
time_keeper = NullTimeKeeper()
@@ -197,31 +176,11 @@ def setup_databases(
if first_alias is None:
first_alias = alias
with time_keeper.timed(" Creating '%s'" % alias):
# RemovedInDjango50Warning: when the deprecation ends,
# replace with:
# serialize_alias = (
# serialized_aliases is None
# or alias in serialized_aliases
# )
try:
serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"]
except KeyError:
serialize_alias = (
serialized_aliases is None or alias in serialized_aliases
)
else:
warnings.warn(
"The SERIALIZE test database setting is "
"deprecated as it can be inferred from the "
"TestCase/TransactionTestCase.databases that "
"enable the serialized_rollback feature.",
category=RemovedInDjango50Warning,
)
connection.creation.create_test_db(
verbosity=verbosity,
autoclobber=not interactive,
keepdb=keepdb,
serialize=serialize_alias,
serialize=connection.settings_dict['TEST'].get('SERIALIZE', True),
)
if parallel > 1:
for index in range(parallel):
@@ -233,15 +192,12 @@ def setup_databases(
)
# Configure all other connections as mirrors of the first one
else:
connections[alias].creation.set_as_test_mirror(
connections[first_alias].settings_dict
)
connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict)
# Configure the test mirrors.
for alias, mirror_alias in mirrored_aliases.items():
connections[alias].creation.set_as_test_mirror(
connections[mirror_alias].settings_dict
)
connections[mirror_alias].settings_dict)
if debug_sql:
for alias in connections:
@@ -250,27 +206,6 @@ def setup_databases(
return old_names
def iter_test_cases(tests):
"""
Return an iterator over a test suite's unittest.TestCase objects.
The tests argument can also be an iterable of TestCase objects.
"""
for test in tests:
if isinstance(test, str):
# Prevent an unfriendly RecursionError that can happen with
# strings.
raise TypeError(
f"Test {test!r} must be a test case or test suite not string "
f"(was found in {tests!r})."
)
if isinstance(test, TestCase):
yield test
else:
# Otherwise, assume it is a test suite.
yield from iter_test_cases(test)
def dependency_ordered(test_databases, dependencies):
"""
Reorder test_databases into an order that honors the dependencies
@@ -334,18 +269,18 @@ def get_unique_databases_and_mirrors(aliases=None):
for alias in connections:
connection = connections[alias]
test_settings = connection.settings_dict["TEST"]
test_settings = connection.settings_dict['TEST']
if test_settings["MIRROR"]:
if test_settings['MIRROR']:
# If the database is marked as a test mirror, save the alias.
mirrored_aliases[alias] = test_settings["MIRROR"]
mirrored_aliases[alias] = test_settings['MIRROR']
elif alias in aliases:
# Store a tuple with DB parameters that uniquely identify it.
# If we have two aliases with the same values for that tuple,
# we only need to create the test database once.
item = test_databases.setdefault(
connection.creation.test_db_signature(),
(connection.settings_dict["NAME"], []),
(connection.settings_dict['NAME'], []),
)
# The default database must be the first because data migrations
# use the default alias by default.
@@ -354,16 +289,11 @@ def get_unique_databases_and_mirrors(aliases=None):
else:
item[1].append(alias)
if "DEPENDENCIES" in test_settings:
dependencies[alias] = test_settings["DEPENDENCIES"]
if 'DEPENDENCIES' in test_settings:
dependencies[alias] = test_settings['DEPENDENCIES']
else:
if (
alias != DEFAULT_DB_ALIAS
and connection.creation.test_db_signature() != default_sig
):
dependencies[alias] = test_settings.get(
"DEPENDENCIES", [DEFAULT_DB_ALIAS]
)
if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig:
dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS])
test_databases = dict(dependency_ordered(test_databases.items(), dependencies))
return test_databases, mirrored_aliases
@@ -385,12 +315,12 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
def get_runner(settings, test_runner_class=None):
test_runner_class = test_runner_class or settings.TEST_RUNNER
test_path = test_runner_class.split(".")
test_path = test_runner_class.split('.')
# Allow for relative paths
if len(test_path) > 1:
test_module_name = ".".join(test_path[:-1])
test_module_name = '.'.join(test_path[:-1])
else:
test_module_name = "."
test_module_name = '.'
test_module = __import__(test_module_name, {}, {}, test_path[-1])
return getattr(test_module, test_path[-1])
@@ -407,7 +337,6 @@ class TestContextDecorator:
`kwarg_name`: keyword argument passing the return value of enable() if
used as a function decorator.
"""
def __init__(self, attr_name=None, kwarg_name=None):
self.attr_name = attr_name
self.kwarg_name = kwarg_name
@@ -437,7 +366,7 @@ class TestContextDecorator:
cls.setUp = setUp
return cls
raise TypeError("Can only decorate subclasses of unittest.TestCase")
raise TypeError('Can only decorate subclasses of unittest.TestCase')
def decorate_callable(self, func):
if asyncio.iscoroutinefunction(func):
@@ -449,16 +378,13 @@ class TestContextDecorator:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return await func(*args, **kwargs)
else:
@wraps(func)
def inner(*args, **kwargs):
with self as context:
if self.kwarg_name:
kwargs[self.kwarg_name] = context
return func(*args, **kwargs)
return inner
def __call__(self, decorated):
@@ -466,7 +392,7 @@ class TestContextDecorator:
return self.decorate_class(decorated)
elif callable(decorated):
return self.decorate_callable(decorated)
raise TypeError("Cannot decorate object of type %s" % type(decorated))
raise TypeError('Cannot decorate object of type %s' % type(decorated))
class override_settings(TestContextDecorator):
@@ -476,7 +402,6 @@ class override_settings(TestContextDecorator):
with the ``with`` statement. In either event, entering/exiting are called
before and after, respectively, the function/block is executed.
"""
enable_exception = None
def __init__(self, **kwargs):
@@ -486,9 +411,9 @@ class override_settings(TestContextDecorator):
def enable(self):
# Keep this code at the beginning to leave the settings unchanged
# in case it raises an exception because INSTALLED_APPS is invalid.
if "INSTALLED_APPS" in self.options:
if 'INSTALLED_APPS' in self.options:
try:
apps.set_installed_apps(self.options["INSTALLED_APPS"])
apps.set_installed_apps(self.options['INSTALLED_APPS'])
except Exception:
apps.unset_installed_apps()
raise
@@ -501,16 +426,14 @@ class override_settings(TestContextDecorator):
try:
setting_changed.send(
sender=settings._wrapped.__class__,
setting=key,
value=new_value,
enter=True,
setting=key, value=new_value, enter=True,
)
except Exception as exc:
self.enable_exception = exc
self.disable()
def disable(self):
if "INSTALLED_APPS" in self.options:
if 'INSTALLED_APPS' in self.options:
apps.unset_installed_apps()
settings._wrapped = self.wrapped
del self.wrapped
@@ -519,9 +442,7 @@ class override_settings(TestContextDecorator):
new_value = getattr(settings, key, None)
responses_for_setting = setting_changed.send_robust(
sender=settings._wrapped.__class__,
setting=key,
value=new_value,
enter=False,
setting=key, value=new_value, enter=False,
)
responses.extend(responses_for_setting)
if self.enable_exception is not None:
@@ -544,12 +465,10 @@ class override_settings(TestContextDecorator):
def decorate_class(self, cls):
from django.test import SimpleTestCase
if not issubclass(cls, SimpleTestCase):
raise ValueError(
"Only subclasses of Django SimpleTestCase can be decorated "
"with override_settings"
)
"with override_settings")
self.save_options(cls)
return cls
@@ -559,7 +478,6 @@ class modify_settings(override_settings):
Like override_settings, but makes it possible to append, prepend, or remove
items instead of redefining the entire list.
"""
def __init__(self, *args, **kwargs):
if args:
# Hack used when instantiating from SimpleTestCase.setUpClass.
@@ -575,9 +493,8 @@ class modify_settings(override_settings):
test_func._modified_settings = self.operations
else:
# Duplicate list to prevent subclasses from altering their parent.
test_func._modified_settings = (
list(test_func._modified_settings) + self.operations
)
test_func._modified_settings = list(
test_func._modified_settings) + self.operations
def enable(self):
self.options = {}
@@ -592,11 +509,11 @@ class modify_settings(override_settings):
# items my be a single value or an iterable.
if isinstance(items, str):
items = [items]
if action == "append":
if action == 'append':
value = value + [item for item in items if item not in value]
elif action == "prepend":
elif action == 'prepend':
value = [item for item in items if item not in value] + value
elif action == "remove":
elif action == 'remove':
value = [item for item in value if item not in items]
else:
raise ValueError("Unsupported action: %s" % action)
@@ -610,10 +527,8 @@ class override_system_checks(TestContextDecorator):
Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
you also need to exclude its system checks.
"""
def __init__(self, new_checks, deployment_checks=None):
from django.core.checks.registry import registry
self.registry = registry
self.new_checks = new_checks
self.deployment_checks = deployment_checks
@@ -623,12 +538,12 @@ class override_system_checks(TestContextDecorator):
self.old_checks = self.registry.registered_checks
self.registry.registered_checks = set()
for check in self.new_checks:
self.registry.register(check, *getattr(check, "tags", ()))
self.registry.register(check, *getattr(check, 'tags', ()))
self.old_deployment_checks = self.registry.deployment_checks
if self.deployment_checks is not None:
self.registry.deployment_checks = set()
for check in self.deployment_checks:
self.registry.register(check, *getattr(check, "tags", ()), deploy=True)
self.registry.register(check, *getattr(check, 'tags', ()), deploy=True)
def disable(self):
self.registry.registered_checks = self.old_checks
@@ -644,18 +559,18 @@ def compare_xml(want, got):
Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
"""
_norm_whitespace_re = re.compile(r"[ \t\n][ \t\n]+")
_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
def norm_whitespace(v):
return _norm_whitespace_re.sub(" ", v)
return _norm_whitespace_re.sub(' ', v)
def child_text(element):
return "".join(
c.data for c in element.childNodes if c.nodeType == Node.TEXT_NODE
)
return ''.join(c.data for c in element.childNodes
if c.nodeType == Node.TEXT_NODE)
def children(element):
return [c for c in element.childNodes if c.nodeType == Node.ELEMENT_NODE]
return [c for c in element.childNodes
if c.nodeType == Node.ELEMENT_NODE]
def norm_child_text(element):
return norm_whitespace(child_text(element))
@@ -674,9 +589,7 @@ def compare_xml(want, got):
got_children = children(got_element)
if len(want_children) != len(got_children):
return False
return all(
check_element(want, got) for want, got in zip(want_children, got_children)
)
return all(check_element(want, got) for want, got in zip(want_children, got_children))
def first_node(document):
for node in document.childNodes:
@@ -687,13 +600,13 @@ def compare_xml(want, got):
):
return node
want = want.strip().replace("\\n", "\n")
got = got.strip().replace("\\n", "\n")
want = want.strip().replace('\\n', '\n')
got = got.strip().replace('\\n', '\n')
# If the string is not a complete xml document, we may need to add a
# root element. This allow us to compare fragments, like "<foo/><bar/>"
if not want.startswith("<?xml"):
wrapper = "<root>%s</root>"
if not want.startswith('<?xml'):
wrapper = '<root>%s</root>'
want = wrapper % want
got = wrapper % got
@@ -708,7 +621,6 @@ class CaptureQueriesContext:
"""
Context manager that captures queries executed by the specified connection.
"""
def __init__(self, connection):
self.connection = connection
@@ -723,7 +635,7 @@ class CaptureQueriesContext:
@property
def captured_queries(self):
return self.connection.queries[self.initial_queries : self.final_queries]
return self.connection.queries[self.initial_queries:self.final_queries]
def __enter__(self):
self.force_debug_cursor = self.connection.force_debug_cursor
@@ -747,7 +659,7 @@ class CaptureQueriesContext:
class ignore_warnings(TestContextDecorator):
def __init__(self, **kwargs):
self.ignore_kwargs = kwargs
if "message" in self.ignore_kwargs or "module" in self.ignore_kwargs:
if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs:
self.filter_func = warnings.filterwarnings
else:
self.filter_func = warnings.simplefilter
@@ -756,7 +668,7 @@ class ignore_warnings(TestContextDecorator):
def enable(self):
self.catch_warnings = warnings.catch_warnings()
self.catch_warnings.__enter__()
self.filter_func("ignore", **self.ignore_kwargs)
self.filter_func('ignore', **self.ignore_kwargs)
def disable(self):
self.catch_warnings.__exit__(*sys.exc_info())
@@ -770,7 +682,7 @@ class ignore_warnings(TestContextDecorator):
requires_tz_support = skipUnless(
TZ_SUPPORT,
"This test relies on the ability to run a program in an arbitrary "
"time zone, but your operating system isn't able to do that.",
"time zone, but your operating system isn't able to do that."
)
@@ -813,9 +725,9 @@ def captured_output(stream_name):
def captured_stdout():
"""Capture the output of sys.stdout:
with captured_stdout() as stdout:
print("hello")
self.assertEqual(stdout.getvalue(), "hello\n")
with captured_stdout() as stdout:
print("hello")
self.assertEqual(stdout.getvalue(), "hello\n")
"""
return captured_output("stdout")
@@ -823,9 +735,9 @@ def captured_stdout():
def captured_stderr():
"""Capture the output of sys.stderr:
with captured_stderr() as stderr:
print("hello", file=sys.stderr)
self.assertEqual(stderr.getvalue(), "hello\n")
with captured_stderr() as stderr:
print("hello", file=sys.stderr)
self.assertEqual(stderr.getvalue(), "hello\n")
"""
return captured_output("stderr")
@@ -833,12 +745,12 @@ def captured_stderr():
def captured_stdin():
"""Capture the input to sys.stdin:
with captured_stdin() as stdin:
stdin.write('hello\n')
stdin.seek(0)
# call test code that consumes from sys.stdin
captured = input()
self.assertEqual(captured, "hello")
with captured_stdin() as stdin:
stdin.write('hello\n')
stdin.seek(0)
# call test code that consumes from sys.stdin
captured = input()
self.assertEqual(captured, "hello")
"""
return captured_output("stdin")
@@ -866,24 +778,18 @@ def require_jinja2(test_func):
Django template engine for a test or skip it if Jinja2 isn't available.
"""
test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func)
return override_settings(
TEMPLATES=[
{
"BACKEND": "django.template.backends.django.DjangoTemplates",
"APP_DIRS": True,
},
{
"BACKEND": "django.template.backends.jinja2.Jinja2",
"APP_DIRS": True,
"OPTIONS": {"keep_trailing_newline": True},
},
]
)(test_func)
return override_settings(TEMPLATES=[{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'APP_DIRS': True,
}, {
'BACKEND': 'django.template.backends.jinja2.Jinja2',
'APP_DIRS': True,
'OPTIONS': {'keep_trailing_newline': True},
}])(test_func)
class override_script_prefix(TestContextDecorator):
"""Decorator or context manager to temporary override the script prefix."""
def __init__(self, prefix):
self.prefix = prefix
super().__init__()
@@ -901,9 +807,8 @@ class LoggingCaptureMixin:
Capture the output from the 'django' logger and store it on the class's
logger_output attribute.
"""
def setUp(self):
self.logger = logging.getLogger("django")
self.logger = logging.getLogger('django')
self.old_stream = self.logger.handlers[0].stream
self.logger_output = StringIO()
self.logger.handlers[0].stream = self.logger_output
@@ -928,7 +833,6 @@ class isolate_apps(TestContextDecorator):
`kwarg_name`: keyword argument passing the isolated registry if used as a
function decorator.
"""
def __init__(self, *installed_apps, **kwargs):
self.installed_apps = installed_apps
super().__init__(**kwargs)
@@ -936,11 +840,11 @@ class isolate_apps(TestContextDecorator):
def enable(self):
self.old_apps = Options.default_apps
apps = Apps(self.installed_apps)
setattr(Options, "default_apps", apps)
setattr(Options, 'default_apps', apps)
return apps
def disable(self):
setattr(Options, "default_apps", self.old_apps)
setattr(Options, 'default_apps', self.old_apps)
class TimeKeeper:
@@ -960,7 +864,7 @@ class TimeKeeper:
def print_results(self):
for name, end_times in self.records.items():
for record_time in end_times:
record = "%s took %.3fs" % (name, record_time)
record = '%s took %.3fs' % (name, record_time)
sys.stderr.write(record + os.linesep)
@@ -975,14 +879,12 @@ class NullTimeKeeper:
def tag(*tags):
"""Decorator to add tags to a test class or method."""
def decorator(obj):
if hasattr(obj, "tags"):
if hasattr(obj, 'tags'):
obj.tags = obj.tags.union(tags)
else:
setattr(obj, "tags", set(tags))
setattr(obj, 'tags', set(tags))
return obj
return decorator