测试gitnore
This commit is contained in:
@@ -1,38 +1,21 @@
|
||||
"""Django Unit Test framework."""
|
||||
|
||||
from django.test.client import AsyncClient, AsyncRequestFactory, Client, RequestFactory
|
||||
from django.test.client import (
|
||||
AsyncClient, AsyncRequestFactory, Client, RequestFactory,
|
||||
)
|
||||
from django.test.testcases import (
|
||||
LiveServerTestCase,
|
||||
SimpleTestCase,
|
||||
TestCase,
|
||||
TransactionTestCase,
|
||||
skipIfDBFeature,
|
||||
skipUnlessAnyDBFeature,
|
||||
skipUnlessDBFeature,
|
||||
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
|
||||
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
|
||||
)
|
||||
from django.test.utils import (
|
||||
ignore_warnings,
|
||||
modify_settings,
|
||||
override_settings,
|
||||
override_system_checks,
|
||||
tag,
|
||||
ignore_warnings, modify_settings, override_settings,
|
||||
override_system_checks, tag,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AsyncClient",
|
||||
"AsyncRequestFactory",
|
||||
"Client",
|
||||
"RequestFactory",
|
||||
"TestCase",
|
||||
"TransactionTestCase",
|
||||
"SimpleTestCase",
|
||||
"LiveServerTestCase",
|
||||
"skipIfDBFeature",
|
||||
"skipUnlessAnyDBFeature",
|
||||
"skipUnlessDBFeature",
|
||||
"ignore_warnings",
|
||||
"modify_settings",
|
||||
"override_settings",
|
||||
"override_system_checks",
|
||||
"tag",
|
||||
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
|
||||
'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
|
||||
'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
|
||||
'ignore_warnings', 'modify_settings', 'override_settings',
|
||||
'override_system_checks', 'tag',
|
||||
]
|
||||
|
||||
@@ -16,7 +16,9 @@ from django.core.handlers.asgi import ASGIRequest
|
||||
from django.core.handlers.base import BaseHandler
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.core.signals import got_request_exception, request_finished, request_started
|
||||
from django.core.signals import (
|
||||
got_request_exception, request_finished, request_started,
|
||||
)
|
||||
from django.db import close_old_connections
|
||||
from django.http import HttpRequest, QueryDict, SimpleCookie
|
||||
from django.test import signals
|
||||
@@ -29,26 +31,20 @@ from django.utils.itercompat import is_iterable
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
__all__ = (
|
||||
"AsyncClient",
|
||||
"AsyncRequestFactory",
|
||||
"Client",
|
||||
"RedirectCycleError",
|
||||
"RequestFactory",
|
||||
"encode_file",
|
||||
"encode_multipart",
|
||||
'AsyncClient', 'AsyncRequestFactory', 'Client', 'RedirectCycleError',
|
||||
'RequestFactory', 'encode_file', 'encode_multipart',
|
||||
)
|
||||
|
||||
|
||||
BOUNDARY = "BoUnDaRyStRiNg"
|
||||
MULTIPART_CONTENT = "multipart/form-data; boundary=%s" % BOUNDARY
|
||||
CONTENT_TYPE_RE = _lazy_re_compile(r".*; charset=([\w\d-]+);?")
|
||||
BOUNDARY = 'BoUnDaRyStRiNg'
|
||||
MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
|
||||
CONTENT_TYPE_RE = _lazy_re_compile(r'.*; charset=([\w\d-]+);?')
|
||||
# Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8
|
||||
JSON_CONTENT_TYPE_RE = _lazy_re_compile(r"^application\/(.+\+)?json")
|
||||
JSON_CONTENT_TYPE_RE = _lazy_re_compile(r'^application\/(.+\+)?json')
|
||||
|
||||
|
||||
class RedirectCycleError(Exception):
|
||||
"""The test client has been asked to follow a redirect loop."""
|
||||
|
||||
def __init__(self, message, last_response):
|
||||
super().__init__(message)
|
||||
self.last_response = last_response
|
||||
@@ -62,7 +58,6 @@ class FakePayload:
|
||||
length. This makes sure that views can't do anything under the test client
|
||||
that wouldn't work in real life.
|
||||
"""
|
||||
|
||||
def __init__(self, content=None):
|
||||
self.__content = BytesIO()
|
||||
self.__len = 0
|
||||
@@ -79,9 +74,7 @@ class FakePayload:
|
||||
self.read_started = True
|
||||
if num_bytes is None:
|
||||
num_bytes = self.__len or 0
|
||||
assert (
|
||||
self.__len >= num_bytes
|
||||
), "Cannot read more than the available bytes from the HTTP incoming data."
|
||||
assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
|
||||
content = self.__content.read(num_bytes)
|
||||
self.__len -= num_bytes
|
||||
return content
|
||||
@@ -99,13 +92,13 @@ def closing_iterator_wrapper(iterable, close):
|
||||
yield from iterable
|
||||
finally:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
close() # will fire request_finished
|
||||
close() # will fire request_finished
|
||||
request_finished.connect(close_old_connections)
|
||||
|
||||
|
||||
def conditional_content_removal(request, response):
|
||||
"""
|
||||
Simulate the behavior of most web servers by removing the content of
|
||||
Simulate the behavior of most Web servers by removing the content of
|
||||
responses for HEAD requests, 1xx, 204, and 304 responses. Ensure
|
||||
compliance with RFC 7230, section 3.3.3.
|
||||
"""
|
||||
@@ -113,22 +106,21 @@ def conditional_content_removal(request, response):
|
||||
if response.streaming:
|
||||
response.streaming_content = []
|
||||
else:
|
||||
response.content = b""
|
||||
if request.method == "HEAD":
|
||||
response.content = b''
|
||||
if request.method == 'HEAD':
|
||||
if response.streaming:
|
||||
response.streaming_content = []
|
||||
else:
|
||||
response.content = b""
|
||||
response.content = b''
|
||||
return response
|
||||
|
||||
|
||||
class ClientHandler(BaseHandler):
|
||||
"""
|
||||
An HTTP Handler that can be used for testing purposes. Use the WSGI
|
||||
A HTTP Handler that can be used for testing purposes. Use the WSGI
|
||||
interface to compose requests, but return the raw HttpResponse object with
|
||||
the originating WSGIRequest attached to its ``wsgi_request`` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
|
||||
self.enforce_csrf_checks = enforce_csrf_checks
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -152,7 +144,7 @@ class ClientHandler(BaseHandler):
|
||||
# Request goes through middleware.
|
||||
response = self.get_response(request)
|
||||
|
||||
# Simulate behaviors of most web servers.
|
||||
# Simulate behaviors of most Web servers.
|
||||
conditional_content_removal(request, response)
|
||||
|
||||
# Attach the originating request to the response so that it could be
|
||||
@@ -162,11 +154,10 @@ class ClientHandler(BaseHandler):
|
||||
# Emulate a WSGI server by calling the close method on completion.
|
||||
if response.streaming:
|
||||
response.streaming_content = closing_iterator_wrapper(
|
||||
response.streaming_content, response.close
|
||||
)
|
||||
response.streaming_content, response.close)
|
||||
else:
|
||||
request_finished.disconnect(close_old_connections)
|
||||
response.close() # will fire request_finished
|
||||
response.close() # will fire request_finished
|
||||
request_finished.connect(close_old_connections)
|
||||
|
||||
return response
|
||||
@@ -174,7 +165,6 @@ class ClientHandler(BaseHandler):
|
||||
|
||||
class AsyncClientHandler(BaseHandler):
|
||||
"""An async version of ClientHandler."""
|
||||
|
||||
def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
|
||||
self.enforce_csrf_checks = enforce_csrf_checks
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -185,15 +175,13 @@ class AsyncClientHandler(BaseHandler):
|
||||
if self._middleware_chain is None:
|
||||
self.load_middleware(is_async=True)
|
||||
# Extract body file from the scope, if provided.
|
||||
if "_body_file" in scope:
|
||||
body_file = scope.pop("_body_file")
|
||||
if '_body_file' in scope:
|
||||
body_file = scope.pop('_body_file')
|
||||
else:
|
||||
body_file = FakePayload("")
|
||||
body_file = FakePayload('')
|
||||
|
||||
request_started.disconnect(close_old_connections)
|
||||
await sync_to_async(request_started.send, thread_sensitive=False)(
|
||||
sender=self.__class__, scope=scope
|
||||
)
|
||||
await sync_to_async(request_started.send, thread_sensitive=False)(sender=self.__class__, scope=scope)
|
||||
request_started.connect(close_old_connections)
|
||||
request = ASGIRequest(scope, body_file)
|
||||
# Sneaky little hack so that we can easily get round
|
||||
@@ -202,16 +190,14 @@ class AsyncClientHandler(BaseHandler):
|
||||
request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
|
||||
# Request goes through middleware.
|
||||
response = await self.get_response_async(request)
|
||||
# Simulate behaviors of most web servers.
|
||||
# Simulate behaviors of most Web servers.
|
||||
conditional_content_removal(request, response)
|
||||
# Attach the originating ASGI request to the response so that it could
|
||||
# be later retrieved.
|
||||
response.asgi_request = request
|
||||
# Emulate a server by calling the close method on completion.
|
||||
if response.streaming:
|
||||
response.streaming_content = await sync_to_async(
|
||||
closing_iterator_wrapper, thread_sensitive=False
|
||||
)(
|
||||
response.streaming_content = await sync_to_async(closing_iterator_wrapper, thread_sensitive=False)(
|
||||
response.streaming_content,
|
||||
response.close,
|
||||
)
|
||||
@@ -230,10 +216,10 @@ def store_rendered_templates(store, signal, sender, template, context, **kwargs)
|
||||
The context is copied so that it is an accurate representation at the time
|
||||
of rendering.
|
||||
"""
|
||||
store.setdefault("templates", []).append(template)
|
||||
if "context" not in store:
|
||||
store["context"] = ContextList()
|
||||
store["context"].append(copy(context))
|
||||
store.setdefault('templates', []).append(template)
|
||||
if 'context' not in store:
|
||||
store['context'] = ContextList()
|
||||
store['context'].append(copy(context))
|
||||
|
||||
|
||||
def encode_multipart(boundary, data):
|
||||
@@ -269,33 +255,25 @@ def encode_multipart(boundary, data):
|
||||
if is_file(item):
|
||||
lines.extend(encode_file(boundary, key, item))
|
||||
else:
|
||||
lines.extend(
|
||||
to_bytes(val)
|
||||
for val in [
|
||||
"--%s" % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
"",
|
||||
item,
|
||||
]
|
||||
)
|
||||
lines.extend(to_bytes(val) for val in [
|
||||
'--%s' % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
'',
|
||||
item
|
||||
])
|
||||
else:
|
||||
lines.extend(
|
||||
to_bytes(val)
|
||||
for val in [
|
||||
"--%s" % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
"",
|
||||
value,
|
||||
]
|
||||
)
|
||||
lines.extend(to_bytes(val) for val in [
|
||||
'--%s' % boundary,
|
||||
'Content-Disposition: form-data; name="%s"' % key,
|
||||
'',
|
||||
value
|
||||
])
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
to_bytes("--%s--" % boundary),
|
||||
b"",
|
||||
]
|
||||
)
|
||||
return b"\r\n".join(lines)
|
||||
lines.extend([
|
||||
to_bytes('--%s--' % boundary),
|
||||
b'',
|
||||
])
|
||||
return b'\r\n'.join(lines)
|
||||
|
||||
|
||||
def encode_file(boundary, key, file):
|
||||
@@ -304,10 +282,10 @@ def encode_file(boundary, key, file):
|
||||
|
||||
# file.name might not be a string. For example, it's an int for
|
||||
# tempfile.TemporaryFile().
|
||||
file_has_string_name = hasattr(file, "name") and isinstance(file.name, str)
|
||||
filename = os.path.basename(file.name) if file_has_string_name else ""
|
||||
file_has_string_name = hasattr(file, 'name') and isinstance(file.name, str)
|
||||
filename = os.path.basename(file.name) if file_has_string_name else ''
|
||||
|
||||
if hasattr(file, "content_type"):
|
||||
if hasattr(file, 'content_type'):
|
||||
content_type = file.content_type
|
||||
elif filename:
|
||||
content_type = mimetypes.guess_type(filename)[0]
|
||||
@@ -315,16 +293,15 @@ def encode_file(boundary, key, file):
|
||||
content_type = None
|
||||
|
||||
if content_type is None:
|
||||
content_type = "application/octet-stream"
|
||||
content_type = 'application/octet-stream'
|
||||
filename = filename or key
|
||||
return [
|
||||
to_bytes("--%s" % boundary),
|
||||
to_bytes(
|
||||
'Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename)
|
||||
),
|
||||
to_bytes("Content-Type: %s" % content_type),
|
||||
b"",
|
||||
to_bytes(file.read()),
|
||||
to_bytes('--%s' % boundary),
|
||||
to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"'
|
||||
% (key, filename)),
|
||||
to_bytes('Content-Type: %s' % content_type),
|
||||
b'',
|
||||
to_bytes(file.read())
|
||||
]
|
||||
|
||||
|
||||
@@ -341,7 +318,6 @@ class RequestFactory:
|
||||
Once you have a request object you can pass it to any view function,
|
||||
just as if that view had been hooked up using a URLconf.
|
||||
"""
|
||||
|
||||
def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
|
||||
self.json_encoder = json_encoder
|
||||
self.defaults = defaults
|
||||
@@ -357,26 +333,24 @@ class RequestFactory:
|
||||
# - REMOTE_ADDR: often useful, see #8551.
|
||||
# See https://www.python.org/dev/peps/pep-3333/#environ-variables
|
||||
return {
|
||||
"HTTP_COOKIE": "; ".join(
|
||||
sorted(
|
||||
"%s=%s" % (morsel.key, morsel.coded_value)
|
||||
for morsel in self.cookies.values()
|
||||
)
|
||||
),
|
||||
"PATH_INFO": "/",
|
||||
"REMOTE_ADDR": "127.0.0.1",
|
||||
"REQUEST_METHOD": "GET",
|
||||
"SCRIPT_NAME": "",
|
||||
"SERVER_NAME": "testserver",
|
||||
"SERVER_PORT": "80",
|
||||
"SERVER_PROTOCOL": "HTTP/1.1",
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": "http",
|
||||
"wsgi.input": FakePayload(b""),
|
||||
"wsgi.errors": self.errors,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.multithread": False,
|
||||
"wsgi.run_once": False,
|
||||
'HTTP_COOKIE': '; '.join(sorted(
|
||||
'%s=%s' % (morsel.key, morsel.coded_value)
|
||||
for morsel in self.cookies.values()
|
||||
)),
|
||||
'PATH_INFO': '/',
|
||||
'REMOTE_ADDR': '127.0.0.1',
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'SCRIPT_NAME': '',
|
||||
'SERVER_NAME': 'testserver',
|
||||
'SERVER_PORT': '80',
|
||||
'SERVER_PROTOCOL': 'HTTP/1.1',
|
||||
'wsgi.version': (1, 0),
|
||||
'wsgi.url_scheme': 'http',
|
||||
'wsgi.input': FakePayload(b''),
|
||||
'wsgi.errors': self.errors,
|
||||
'wsgi.multiprocess': True,
|
||||
'wsgi.multithread': False,
|
||||
'wsgi.run_once': False,
|
||||
**self.defaults,
|
||||
**request,
|
||||
}
|
||||
@@ -402,9 +376,7 @@ class RequestFactory:
|
||||
Return encoded JSON if data is a dict, list, or tuple and content_type
|
||||
is application/json.
|
||||
"""
|
||||
should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(
|
||||
data, (dict, list, tuple)
|
||||
)
|
||||
should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple))
|
||||
return json.dumps(data, cls=self.json_encoder) if should_encode else data
|
||||
|
||||
def _get_path(self, parsed):
|
||||
@@ -416,128 +388,88 @@ class RequestFactory:
|
||||
# Replace the behavior where non-ASCII values in the WSGI environ are
|
||||
# arbitrarily decoded with ISO-8859-1.
|
||||
# Refs comment in `get_bytes_from_wsgi()`.
|
||||
return path.decode("iso-8859-1")
|
||||
return path.decode('iso-8859-1')
|
||||
|
||||
def get(self, path, data=None, secure=False, **extra):
|
||||
"""Construct a GET request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic(
|
||||
"GET",
|
||||
path,
|
||||
secure=secure,
|
||||
**{
|
||||
"QUERY_STRING": urlencode(data, doseq=True),
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
return self.generic('GET', path, secure=secure, **{
|
||||
'QUERY_STRING': urlencode(data, doseq=True),
|
||||
**extra,
|
||||
})
|
||||
|
||||
def post(
|
||||
self, path, data=None, content_type=MULTIPART_CONTENT, secure=False, **extra
|
||||
):
|
||||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
|
||||
secure=False, **extra):
|
||||
"""Construct a POST request."""
|
||||
data = self._encode_json({} if data is None else data, content_type)
|
||||
post_data = self._encode_data(data, content_type)
|
||||
|
||||
return self.generic(
|
||||
"POST", path, post_data, content_type, secure=secure, **extra
|
||||
)
|
||||
return self.generic('POST', path, post_data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def head(self, path, data=None, secure=False, **extra):
|
||||
"""Construct a HEAD request."""
|
||||
data = {} if data is None else data
|
||||
return self.generic(
|
||||
"HEAD",
|
||||
path,
|
||||
secure=secure,
|
||||
**{
|
||||
"QUERY_STRING": urlencode(data, doseq=True),
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
return self.generic('HEAD', path, secure=secure, **{
|
||||
'QUERY_STRING': urlencode(data, doseq=True),
|
||||
**extra,
|
||||
})
|
||||
|
||||
def trace(self, path, secure=False, **extra):
|
||||
"""Construct a TRACE request."""
|
||||
return self.generic("TRACE", path, secure=secure, **extra)
|
||||
return self.generic('TRACE', path, secure=secure, **extra)
|
||||
|
||||
def options(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def options(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"Construct an OPTIONS request."
|
||||
return self.generic("OPTIONS", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic('OPTIONS', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def put(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def put(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a PUT request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic("PUT", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic('PUT', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def patch(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def patch(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a PATCH request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic("PATCH", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic('PATCH', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def delete(self, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra):
|
||||
"""Construct a DELETE request."""
|
||||
data = self._encode_json(data, content_type)
|
||||
return self.generic("DELETE", path, data, content_type, secure=secure, **extra)
|
||||
return self.generic('DELETE', path, data, content_type,
|
||||
secure=secure, **extra)
|
||||
|
||||
def generic(
|
||||
self,
|
||||
method,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def generic(self, method, path, data='',
|
||||
content_type='application/octet-stream', secure=False,
|
||||
**extra):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
parsed = urlparse(str(path)) # path can be lazy
|
||||
data = force_bytes(data, settings.DEFAULT_CHARSET)
|
||||
r = {
|
||||
"PATH_INFO": self._get_path(parsed),
|
||||
"REQUEST_METHOD": method,
|
||||
"SERVER_PORT": "443" if secure else "80",
|
||||
"wsgi.url_scheme": "https" if secure else "http",
|
||||
'PATH_INFO': self._get_path(parsed),
|
||||
'REQUEST_METHOD': method,
|
||||
'SERVER_PORT': '443' if secure else '80',
|
||||
'wsgi.url_scheme': 'https' if secure else 'http',
|
||||
}
|
||||
if data:
|
||||
r.update(
|
||||
{
|
||||
"CONTENT_LENGTH": str(len(data)),
|
||||
"CONTENT_TYPE": content_type,
|
||||
"wsgi.input": FakePayload(data),
|
||||
}
|
||||
)
|
||||
r.update({
|
||||
'CONTENT_LENGTH': str(len(data)),
|
||||
'CONTENT_TYPE': content_type,
|
||||
'wsgi.input': FakePayload(data),
|
||||
})
|
||||
r.update(extra)
|
||||
# If QUERY_STRING is absent or empty, we want to extract it from the URL.
|
||||
if not r.get("QUERY_STRING"):
|
||||
if not r.get('QUERY_STRING'):
|
||||
# WSGI requires latin-1 encoded strings. See get_path_info().
|
||||
query_string = parsed[4].encode().decode("iso-8859-1")
|
||||
r["QUERY_STRING"] = query_string
|
||||
query_string = parsed[4].encode().decode('iso-8859-1')
|
||||
r['QUERY_STRING'] = query_string
|
||||
return self.request(**r)
|
||||
|
||||
|
||||
@@ -555,35 +487,30 @@ class AsyncRequestFactory(RequestFactory):
|
||||
a) this makes ASGIRequest subclasses, and
|
||||
b) AsyncTestClient can subclass it.
|
||||
"""
|
||||
|
||||
def _base_scope(self, **request):
|
||||
"""The base scope for a request."""
|
||||
# This is a minimal valid ASGI scope, plus:
|
||||
# - headers['cookie'] for cookie support,
|
||||
# - 'client' often useful, see #8551.
|
||||
scope = {
|
||||
"asgi": {"version": "3.0"},
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"client": ["127.0.0.1", 0],
|
||||
"server": ("testserver", "80"),
|
||||
"scheme": "http",
|
||||
"method": "GET",
|
||||
"headers": [],
|
||||
'asgi': {'version': '3.0'},
|
||||
'type': 'http',
|
||||
'http_version': '1.1',
|
||||
'client': ['127.0.0.1', 0],
|
||||
'server': ('testserver', '80'),
|
||||
'scheme': 'http',
|
||||
'method': 'GET',
|
||||
'headers': [],
|
||||
**self.defaults,
|
||||
**request,
|
||||
}
|
||||
scope["headers"].append(
|
||||
(
|
||||
b"cookie",
|
||||
b"; ".join(
|
||||
sorted(
|
||||
("%s=%s" % (morsel.key, morsel.coded_value)).encode("ascii")
|
||||
for morsel in self.cookies.values()
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
scope['headers'].append((
|
||||
b'cookie',
|
||||
b'; '.join(sorted(
|
||||
('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
|
||||
for morsel in self.cookies.values()
|
||||
)),
|
||||
))
|
||||
return scope
|
||||
|
||||
def request(self, **request):
|
||||
@@ -591,52 +518,43 @@ class AsyncRequestFactory(RequestFactory):
|
||||
# This is synchronous, which means all methods on this class are.
|
||||
# AsyncClient, however, has an async request function, which makes all
|
||||
# its methods async.
|
||||
if "_body_file" in request:
|
||||
body_file = request.pop("_body_file")
|
||||
if '_body_file' in request:
|
||||
body_file = request.pop('_body_file')
|
||||
else:
|
||||
body_file = FakePayload("")
|
||||
body_file = FakePayload('')
|
||||
return ASGIRequest(self._base_scope(**request), body_file)
|
||||
|
||||
def generic(
|
||||
self,
|
||||
method,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
secure=False,
|
||||
**extra,
|
||||
self, method, path, data='', content_type='application/octet-stream',
|
||||
secure=False, **extra,
|
||||
):
|
||||
"""Construct an arbitrary HTTP request."""
|
||||
parsed = urlparse(str(path)) # path can be lazy.
|
||||
data = force_bytes(data, settings.DEFAULT_CHARSET)
|
||||
s = {
|
||||
"method": method,
|
||||
"path": self._get_path(parsed),
|
||||
"server": ("127.0.0.1", "443" if secure else "80"),
|
||||
"scheme": "https" if secure else "http",
|
||||
"headers": [(b"host", b"testserver")],
|
||||
'method': method,
|
||||
'path': self._get_path(parsed),
|
||||
'server': ('127.0.0.1', '443' if secure else '80'),
|
||||
'scheme': 'https' if secure else 'http',
|
||||
'headers': [(b'host', b'testserver')],
|
||||
}
|
||||
if data:
|
||||
s["headers"].extend(
|
||||
[
|
||||
(b"content-length", str(len(data)).encode("ascii")),
|
||||
(b"content-type", content_type.encode("ascii")),
|
||||
]
|
||||
)
|
||||
s["_body_file"] = FakePayload(data)
|
||||
follow = extra.pop("follow", None)
|
||||
s['headers'].extend([
|
||||
(b'content-length', str(len(data)).encode('ascii')),
|
||||
(b'content-type', content_type.encode('ascii')),
|
||||
])
|
||||
s['_body_file'] = FakePayload(data)
|
||||
follow = extra.pop('follow', None)
|
||||
if follow is not None:
|
||||
s["follow"] = follow
|
||||
if query_string := extra.pop("QUERY_STRING", None):
|
||||
s["query_string"] = query_string
|
||||
s["headers"] += [
|
||||
(key.lower().encode("ascii"), value.encode("latin1"))
|
||||
s['follow'] = follow
|
||||
s['headers'] += [
|
||||
(key.lower().encode('ascii'), value.encode('latin1'))
|
||||
for key, value in extra.items()
|
||||
]
|
||||
# If QUERY_STRING is absent or empty, we want to extract it from the
|
||||
# URL.
|
||||
if not s.get("query_string"):
|
||||
s["query_string"] = parsed[4]
|
||||
if not s.get('query_string'):
|
||||
s['query_string'] = parsed[4]
|
||||
return self.request(**s)
|
||||
|
||||
|
||||
@@ -644,7 +562,6 @@ class ClientMixin:
|
||||
"""
|
||||
Mixin with common methods between Client and AsyncClient.
|
||||
"""
|
||||
|
||||
def store_exc_info(self, **kwargs):
|
||||
"""Store exceptions when they are generated by a view."""
|
||||
self.exc_info = sys.exc_info()
|
||||
@@ -682,7 +599,6 @@ class ClientMixin:
|
||||
are incorrect.
|
||||
"""
|
||||
from django.contrib.auth import authenticate
|
||||
|
||||
user = authenticate(**credentials)
|
||||
if user:
|
||||
self._login(user)
|
||||
@@ -692,10 +608,9 @@ class ClientMixin:
|
||||
def force_login(self, user, backend=None):
|
||||
def get_backend():
|
||||
from django.contrib.auth import load_backend
|
||||
|
||||
for backend_path in settings.AUTHENTICATION_BACKENDS:
|
||||
backend = load_backend(backend_path)
|
||||
if hasattr(backend, "get_user"):
|
||||
if hasattr(backend, 'get_user'):
|
||||
return backend_path
|
||||
|
||||
if backend is None:
|
||||
@@ -720,18 +635,17 @@ class ClientMixin:
|
||||
session_cookie = settings.SESSION_COOKIE_NAME
|
||||
self.cookies[session_cookie] = request.session.session_key
|
||||
cookie_data = {
|
||||
"max-age": None,
|
||||
"path": "/",
|
||||
"domain": settings.SESSION_COOKIE_DOMAIN,
|
||||
"secure": settings.SESSION_COOKIE_SECURE or None,
|
||||
"expires": None,
|
||||
'max-age': None,
|
||||
'path': '/',
|
||||
'domain': settings.SESSION_COOKIE_DOMAIN,
|
||||
'secure': settings.SESSION_COOKIE_SECURE or None,
|
||||
'expires': None,
|
||||
}
|
||||
self.cookies[session_cookie].update(cookie_data)
|
||||
|
||||
def logout(self):
|
||||
"""Log out the user by removing the cookies and session object."""
|
||||
from django.contrib.auth import get_user, logout
|
||||
|
||||
request = HttpRequest()
|
||||
if self.session:
|
||||
request.session = self.session
|
||||
@@ -743,15 +657,13 @@ class ClientMixin:
|
||||
self.cookies = SimpleCookie()
|
||||
|
||||
def _parse_json(self, response, **extra):
|
||||
if not hasattr(response, "_json"):
|
||||
if not JSON_CONTENT_TYPE_RE.match(response.get("Content-Type")):
|
||||
if not hasattr(response, '_json'):
|
||||
if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
|
||||
raise ValueError(
|
||||
'Content-Type header is "%s", not "application/json"'
|
||||
% response.get("Content-Type")
|
||||
% response.get('Content-Type')
|
||||
)
|
||||
response._json = json.loads(
|
||||
response.content.decode(response.charset), **extra
|
||||
)
|
||||
response._json = json.loads(response.content.decode(response.charset), **extra)
|
||||
return response._json
|
||||
|
||||
|
||||
@@ -773,10 +685,7 @@ class Client(ClientMixin, RequestFactory):
|
||||
contexts and templates produced by a view, rather than the
|
||||
HTML rendered to the end-user.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
|
||||
):
|
||||
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
|
||||
super().__init__(**defaults)
|
||||
self.handler = ClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
@@ -812,14 +721,11 @@ class Client(ClientMixin, RequestFactory):
|
||||
response.client = self
|
||||
response.request = request
|
||||
# Add any rendered template detail to the response.
|
||||
response.templates = data.get("templates", [])
|
||||
response.context = data.get("context")
|
||||
response.templates = data.get('templates', [])
|
||||
response.context = data.get('context')
|
||||
response.json = partial(self._parse_json, response)
|
||||
# Attach the ResolverMatch instance to the response.
|
||||
urlconf = getattr(response.wsgi_request, "urlconf", None)
|
||||
response.resolver_match = SimpleLazyObject(
|
||||
lambda: resolve(request["PATH_INFO"], urlconf=urlconf),
|
||||
)
|
||||
response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO']))
|
||||
# Flatten a single context. Not really necessary anymore thanks to the
|
||||
# __getattr__ flattening in ContextList, but has some edge case
|
||||
# backwards compatibility implications.
|
||||
@@ -838,24 +744,13 @@ class Client(ClientMixin, RequestFactory):
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def post(
|
||||
self,
|
||||
path,
|
||||
data=None,
|
||||
content_type=MULTIPART_CONTENT,
|
||||
follow=False,
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using POST."""
|
||||
self.extra = extra
|
||||
response = super().post(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
)
|
||||
response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
)
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def head(self, path, data=None, follow=False, secure=False, **extra):
|
||||
@@ -866,87 +761,43 @@ class Client(ClientMixin, RequestFactory):
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def options(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def options(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Request a response from the server using OPTIONS."""
|
||||
self.extra = extra
|
||||
response = super().options(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
)
|
||||
response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
)
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def put(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def put(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PUT."""
|
||||
self.extra = extra
|
||||
response = super().put(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
)
|
||||
response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
)
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def patch(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def patch(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a resource to the server using PATCH."""
|
||||
self.extra = extra
|
||||
response = super().patch(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
)
|
||||
response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
)
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def delete(
|
||||
self,
|
||||
path,
|
||||
data="",
|
||||
content_type="application/octet-stream",
|
||||
follow=False,
|
||||
secure=False,
|
||||
**extra,
|
||||
):
|
||||
def delete(self, path, data='', content_type='application/octet-stream',
|
||||
follow=False, secure=False, **extra):
|
||||
"""Send a DELETE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().delete(
|
||||
path, data=data, content_type=content_type, secure=secure, **extra
|
||||
)
|
||||
response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
|
||||
if follow:
|
||||
response = self._handle_redirects(
|
||||
response, data=data, content_type=content_type, **extra
|
||||
)
|
||||
response = self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
return response
|
||||
|
||||
def trace(self, path, data="", follow=False, secure=False, **extra):
|
||||
def trace(self, path, data='', follow=False, secure=False, **extra):
|
||||
"""Send a TRACE request to the server."""
|
||||
self.extra = extra
|
||||
response = super().trace(path, data=data, secure=secure, **extra)
|
||||
@@ -954,7 +805,7 @@ class Client(ClientMixin, RequestFactory):
|
||||
response = self._handle_redirects(response, data=data, **extra)
|
||||
return response
|
||||
|
||||
def _handle_redirects(self, response, data="", content_type="", **extra):
|
||||
def _handle_redirects(self, response, data='', content_type='', **extra):
|
||||
"""
|
||||
Follow any redirects by requesting responses from the server using GET.
|
||||
"""
|
||||
@@ -973,46 +824,36 @@ class Client(ClientMixin, RequestFactory):
|
||||
|
||||
url = urlsplit(response_url)
|
||||
if url.scheme:
|
||||
extra["wsgi.url_scheme"] = url.scheme
|
||||
extra['wsgi.url_scheme'] = url.scheme
|
||||
if url.hostname:
|
||||
extra["SERVER_NAME"] = url.hostname
|
||||
extra['SERVER_NAME'] = url.hostname
|
||||
if url.port:
|
||||
extra["SERVER_PORT"] = str(url.port)
|
||||
extra['SERVER_PORT'] = str(url.port)
|
||||
|
||||
path = url.path
|
||||
# RFC 2616: bare domains without path are treated as the root.
|
||||
if not path and url.netloc:
|
||||
path = "/"
|
||||
# Prepend the request path to handle relative path redirects
|
||||
if not path.startswith("/"):
|
||||
path = urljoin(response.request["PATH_INFO"], path)
|
||||
path = url.path
|
||||
if not path.startswith('/'):
|
||||
path = urljoin(response.request['PATH_INFO'], path)
|
||||
|
||||
if response.status_code in (
|
||||
HTTPStatus.TEMPORARY_REDIRECT,
|
||||
HTTPStatus.PERMANENT_REDIRECT,
|
||||
):
|
||||
if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT):
|
||||
# Preserve request method and query string (if needed)
|
||||
# post-redirect for 307/308 responses.
|
||||
request_method = response.request["REQUEST_METHOD"].lower()
|
||||
if request_method not in ("get", "head"):
|
||||
extra["QUERY_STRING"] = url.query
|
||||
request_method = response.request['REQUEST_METHOD'].lower()
|
||||
if request_method not in ('get', 'head'):
|
||||
extra['QUERY_STRING'] = url.query
|
||||
request_method = getattr(self, request_method)
|
||||
else:
|
||||
request_method = self.get
|
||||
data = QueryDict(url.query)
|
||||
content_type = None
|
||||
|
||||
response = request_method(
|
||||
path, data=data, content_type=content_type, follow=False, **extra
|
||||
)
|
||||
response = request_method(path, data=data, content_type=content_type, follow=False, **extra)
|
||||
response.redirect_chain = redirect_chain
|
||||
|
||||
if redirect_chain[-1] in redirect_chain[:-1]:
|
||||
# Check that we're not redirecting to somewhere we've already
|
||||
# been to, to prevent loops.
|
||||
raise RedirectCycleError(
|
||||
"Redirect loop detected.", last_response=response
|
||||
)
|
||||
raise RedirectCycleError("Redirect loop detected.", last_response=response)
|
||||
if len(redirect_chain) > 20:
|
||||
# Such a lengthy chain likely also means a loop, but one with
|
||||
# a growing path, changing view, or changing query argument;
|
||||
@@ -1029,10 +870,7 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
|
||||
|
||||
Does not currently support "follow" on its methods.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, enforce_csrf_checks=False, raise_request_exception=True, **defaults
|
||||
):
|
||||
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
|
||||
super().__init__(**defaults)
|
||||
self.handler = AsyncClientHandler(enforce_csrf_checks)
|
||||
self.raise_request_exception = raise_request_exception
|
||||
@@ -1046,19 +884,20 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
|
||||
query environment, which can be overridden using the arguments to the
|
||||
request.
|
||||
"""
|
||||
if "follow" in request:
|
||||
if 'follow' in request:
|
||||
raise NotImplementedError(
|
||||
"AsyncClient request methods do not accept the follow parameter."
|
||||
'AsyncClient request methods do not accept the follow '
|
||||
'parameter.'
|
||||
)
|
||||
scope = self._base_scope(**request)
|
||||
# Curry a data dictionary into an instance of the template renderer
|
||||
# callback function.
|
||||
data = {}
|
||||
on_template_render = partial(store_rendered_templates, data)
|
||||
signal_uid = "template-render-%s" % id(request)
|
||||
signal_uid = 'template-render-%s' % id(request)
|
||||
signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
|
||||
# Capture exceptions created by the handler.
|
||||
exception_uid = "request-exception-%s" % id(request)
|
||||
exception_uid = 'request-exception-%s' % id(request)
|
||||
got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
|
||||
try:
|
||||
response = await self.handler(scope)
|
||||
@@ -1071,14 +910,11 @@ class AsyncClient(ClientMixin, AsyncRequestFactory):
|
||||
response.client = self
|
||||
response.request = request
|
||||
# Add any rendered template detail to the response.
|
||||
response.templates = data.get("templates", [])
|
||||
response.context = data.get("context")
|
||||
response.templates = data.get('templates', [])
|
||||
response.context = data.get('context')
|
||||
response.json = partial(self._parse_json, response)
|
||||
# Attach the ResolverMatch instance to the response.
|
||||
urlconf = getattr(response.asgi_request, "urlconf", None)
|
||||
response.resolver_match = SimpleLazyObject(
|
||||
lambda: resolve(request["path"], urlconf=urlconf),
|
||||
)
|
||||
response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
|
||||
# Flatten a single context. Not really necessary anymore thanks to the
|
||||
# __getattr__ flattening in ContextList, but has some edge case
|
||||
# backwards compatibility implications.
|
||||
|
||||
@@ -7,62 +7,11 @@ from django.utils.regex_helper import _lazy_re_compile
|
||||
# ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020
|
||||
# SPACE.
|
||||
# https://infra.spec.whatwg.org/#ascii-whitespace
|
||||
ASCII_WHITESPACE = _lazy_re_compile(r"[\t\n\f\r ]+")
|
||||
|
||||
# https://html.spec.whatwg.org/#attributes-3
|
||||
BOOLEAN_ATTRIBUTES = {
|
||||
"allowfullscreen",
|
||||
"async",
|
||||
"autofocus",
|
||||
"autoplay",
|
||||
"checked",
|
||||
"controls",
|
||||
"default",
|
||||
"defer ",
|
||||
"disabled",
|
||||
"formnovalidate",
|
||||
"hidden",
|
||||
"ismap",
|
||||
"itemscope",
|
||||
"loop",
|
||||
"multiple",
|
||||
"muted",
|
||||
"nomodule",
|
||||
"novalidate",
|
||||
"open",
|
||||
"playsinline",
|
||||
"readonly",
|
||||
"required",
|
||||
"reversed",
|
||||
"selected",
|
||||
# Attributes for deprecated tags.
|
||||
"truespeed",
|
||||
}
|
||||
ASCII_WHITESPACE = _lazy_re_compile(r'[\t\n\f\r ]+')
|
||||
|
||||
|
||||
def normalize_whitespace(string):
|
||||
return ASCII_WHITESPACE.sub(" ", string)
|
||||
|
||||
|
||||
def normalize_attributes(attributes):
|
||||
normalized = []
|
||||
for name, value in attributes:
|
||||
if name == "class" and value:
|
||||
# Special case handling of 'class' attribute, so that comparisons
|
||||
# of DOM instances are not sensitive to ordering of classes.
|
||||
value = " ".join(
|
||||
sorted(value for value in ASCII_WHITESPACE.split(value) if value)
|
||||
)
|
||||
# Boolean attributes without a value is same as attribute with value
|
||||
# that equals the attributes name. For example:
|
||||
# <input checked> == <input checked="checked">
|
||||
if name in BOOLEAN_ATTRIBUTES:
|
||||
if not value or value == name:
|
||||
value = None
|
||||
elif value is None:
|
||||
value = ""
|
||||
normalized.append((name, value))
|
||||
return normalized
|
||||
return ASCII_WHITESPACE.sub(' ', string)
|
||||
|
||||
|
||||
class Element:
|
||||
@@ -100,14 +49,27 @@ class Element:
|
||||
for i, child in enumerate(self.children):
|
||||
if isinstance(child, str):
|
||||
self.children[i] = child.strip()
|
||||
elif hasattr(child, "finalize"):
|
||||
elif hasattr(child, 'finalize'):
|
||||
child.finalize()
|
||||
|
||||
def __eq__(self, element):
|
||||
if not hasattr(element, "name") or self.name != element.name:
|
||||
if not hasattr(element, 'name') or self.name != element.name:
|
||||
return False
|
||||
if len(self.attributes) != len(element.attributes):
|
||||
return False
|
||||
if self.attributes != element.attributes:
|
||||
return False
|
||||
# attributes without a value is same as attribute with value that
|
||||
# equals the attributes name:
|
||||
# <input checked> == <input checked="checked">
|
||||
for i in range(len(self.attributes)):
|
||||
attr, value = self.attributes[i]
|
||||
other_attr, other_value = element.attributes[i]
|
||||
if value is None:
|
||||
value = attr
|
||||
if other_value is None:
|
||||
other_value = other_attr
|
||||
if attr != other_attr or value != other_value:
|
||||
return False
|
||||
return self.children == element.children
|
||||
|
||||
def __hash__(self):
|
||||
@@ -162,18 +124,18 @@ class Element:
|
||||
return self.children[key]
|
||||
|
||||
def __str__(self):
|
||||
output = "<%s" % self.name
|
||||
output = '<%s' % self.name
|
||||
for key, value in self.attributes:
|
||||
if value is not None:
|
||||
if value:
|
||||
output += ' %s="%s"' % (key, value)
|
||||
else:
|
||||
output += " %s" % key
|
||||
output += ' %s' % key
|
||||
if self.children:
|
||||
output += ">\n"
|
||||
output += "".join(str(c) for c in self.children)
|
||||
output += "\n</%s>" % self.name
|
||||
output += '>\n'
|
||||
output += ''.join(str(c) for c in self.children)
|
||||
output += '\n</%s>' % self.name
|
||||
else:
|
||||
output += ">"
|
||||
output += '>'
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
@@ -185,7 +147,7 @@ class RootElement(Element):
|
||||
super().__init__(None, ())
|
||||
|
||||
def __str__(self):
|
||||
return "".join(str(c) for c in self.children)
|
||||
return ''.join(str(c) for c in self.children)
|
||||
|
||||
|
||||
class HTMLParseError(Exception):
|
||||
@@ -195,23 +157,10 @@ class HTMLParseError(Exception):
|
||||
class Parser(HTMLParser):
|
||||
# https://html.spec.whatwg.org/#void-elements
|
||||
SELF_CLOSING_TAGS = {
|
||||
"area",
|
||||
"base",
|
||||
"br",
|
||||
"col",
|
||||
"embed",
|
||||
"hr",
|
||||
"img",
|
||||
"input",
|
||||
"link",
|
||||
"meta",
|
||||
"param",
|
||||
"source",
|
||||
"track",
|
||||
"wbr",
|
||||
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta',
|
||||
'param', 'source', 'track', 'wbr',
|
||||
# Deprecated tags
|
||||
"frame",
|
||||
"spacer",
|
||||
'frame', 'spacer',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
@@ -228,9 +177,9 @@ class Parser(HTMLParser):
|
||||
position = self.element_positions[element]
|
||||
if position is None:
|
||||
position = self.getpos()
|
||||
if hasattr(position, "lineno"):
|
||||
if hasattr(position, 'lineno'):
|
||||
position = position.lineno, position.offset
|
||||
return "Line %d, Column %d" % position
|
||||
return 'Line %d, Column %d' % position
|
||||
|
||||
@property
|
||||
def current(self):
|
||||
@@ -245,7 +194,14 @@ class Parser(HTMLParser):
|
||||
self.handle_endtag(tag)
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
attrs = normalize_attributes(attrs)
|
||||
# Special case handling of 'class' attribute, so that comparisons of DOM
|
||||
# instances are not sensitive to ordering of classes.
|
||||
attrs = [
|
||||
(name, ' '.join(sorted(value for value in ASCII_WHITESPACE.split(value) if value)))
|
||||
if name == "class"
|
||||
else (name, value)
|
||||
for name, value in attrs
|
||||
]
|
||||
element = Element(tag, attrs)
|
||||
self.current.append(element)
|
||||
if tag not in self.SELF_CLOSING_TAGS:
|
||||
@@ -254,13 +210,13 @@ class Parser(HTMLParser):
|
||||
|
||||
def handle_endtag(self, tag):
|
||||
if not self.open_tags:
|
||||
self.error("Unexpected end tag `%s` (%s)" % (tag, self.format_position()))
|
||||
self.error("Unexpected end tag `%s` (%s)" % (
|
||||
tag, self.format_position()))
|
||||
element = self.open_tags.pop()
|
||||
while element.name != tag:
|
||||
if not self.open_tags:
|
||||
self.error(
|
||||
"Unexpected end tag `%s` (%s)" % (tag, self.format_position())
|
||||
)
|
||||
self.error("Unexpected end tag `%s` (%s)" % (
|
||||
tag, self.format_position()))
|
||||
element = self.open_tags.pop()
|
||||
|
||||
def handle_data(self, data):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,9 +27,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
"""
|
||||
test_class = super().__new__(cls, name, bases, attrs)
|
||||
# If the test class is either browser-specific or a test base, return it.
|
||||
if test_class.browser or not any(
|
||||
name.startswith("test") and callable(value) for name, value in attrs.items()
|
||||
):
|
||||
if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()):
|
||||
return test_class
|
||||
elif test_class.browsers:
|
||||
# Reuse the created test class to make it browser-specific.
|
||||
@@ -39,7 +37,7 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
first_browser = test_class.browsers[0]
|
||||
test_class.browser = first_browser
|
||||
# Listen on an external interface if using a selenium hub.
|
||||
host = test_class.host if not test_class.selenium_hub else "0.0.0.0"
|
||||
host = test_class.host if not test_class.selenium_hub else '0.0.0.0'
|
||||
test_class.host = host
|
||||
test_class.external_host = cls.external_host
|
||||
# Create subclasses for each of the remaining browsers and expose
|
||||
@@ -51,16 +49,16 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
"%s%s" % (capfirst(browser), name),
|
||||
(test_class,),
|
||||
{
|
||||
"browser": browser,
|
||||
"host": host,
|
||||
"external_host": cls.external_host,
|
||||
"__module__": test_class.__module__,
|
||||
},
|
||||
'browser': browser,
|
||||
'host': host,
|
||||
'external_host': cls.external_host,
|
||||
'__module__': test_class.__module__,
|
||||
}
|
||||
)
|
||||
setattr(module, browser_test_class.__name__, browser_test_class)
|
||||
return test_class
|
||||
# If no browsers were specified, skip this class (it'll still be discovered).
|
||||
return unittest.skip("No browsers specified.")(test_class)
|
||||
return unittest.skip('No browsers specified.')(test_class)
|
||||
|
||||
@classmethod
|
||||
def import_webdriver(cls, browser):
|
||||
@@ -68,12 +66,13 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
|
||||
@classmethod
|
||||
def import_options(cls, browser):
|
||||
return import_string("selenium.webdriver.%s.options.Options" % browser)
|
||||
return import_string('selenium.webdriver.%s.options.Options' % browser)
|
||||
|
||||
@classmethod
|
||||
def get_capability(cls, browser):
|
||||
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
|
||||
|
||||
from selenium.webdriver.common.desired_capabilities import (
|
||||
DesiredCapabilities,
|
||||
)
|
||||
return getattr(DesiredCapabilities, browser.upper())
|
||||
|
||||
def create_options(self):
|
||||
@@ -88,7 +87,6 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
def create_webdriver(self):
|
||||
if self.selenium_hub:
|
||||
from selenium import webdriver
|
||||
|
||||
return webdriver.Remote(
|
||||
command_executor=self.selenium_hub,
|
||||
desired_capabilities=self.get_capability(self.browser),
|
||||
@@ -96,14 +94,14 @@ class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||||
return self.import_webdriver(self.browser)(options=self.create_options())
|
||||
|
||||
|
||||
@tag("selenium")
|
||||
@tag('selenium')
|
||||
class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
|
||||
implicit_wait = 10
|
||||
external_host = None
|
||||
|
||||
@classproperty
|
||||
def live_server_url(cls):
|
||||
return "http://%s:%s" % (cls.external_host or cls.host, cls.server_thread.port)
|
||||
return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port)
|
||||
|
||||
@classproperty
|
||||
def allowed_host(cls):
|
||||
@@ -120,7 +118,7 @@ class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
|
||||
# quit() the WebDriver before attempting to terminate and join the
|
||||
# single-threaded LiveServerThread to avoid a dead lock if the browser
|
||||
# kept a connection alive.
|
||||
if hasattr(cls, "selenium"):
|
||||
if hasattr(cls, 'selenium'):
|
||||
cls.selenium.quit()
|
||||
super()._tearDownClassInternal()
|
||||
|
||||
|
||||
@@ -20,14 +20,13 @@ template_rendered = Signal()
|
||||
# except for cases where the receiver is related to a contrib app.
|
||||
|
||||
# Settings that may not work well when using 'override_settings' (#19031)
|
||||
COMPLEX_OVERRIDE_SETTINGS = {"DATABASES"}
|
||||
COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_cache_handlers(**kwargs):
|
||||
if kwargs["setting"] == "CACHES":
|
||||
if kwargs['setting'] == 'CACHES':
|
||||
from django.core.cache import caches, close_caches
|
||||
|
||||
close_caches()
|
||||
caches._settings = caches.settings = caches.configure_settings(None)
|
||||
caches._connections = Local()
|
||||
@@ -35,41 +34,37 @@ def clear_cache_handlers(**kwargs):
|
||||
|
||||
@receiver(setting_changed)
|
||||
def update_installed_apps(**kwargs):
|
||||
if kwargs["setting"] == "INSTALLED_APPS":
|
||||
if kwargs['setting'] == 'INSTALLED_APPS':
|
||||
# Rebuild any AppDirectoriesFinder instance.
|
||||
from django.contrib.staticfiles.finders import get_finder
|
||||
|
||||
get_finder.cache_clear()
|
||||
# Rebuild management commands cache
|
||||
from django.core.management import get_commands
|
||||
|
||||
get_commands.cache_clear()
|
||||
# Rebuild get_app_template_dirs cache.
|
||||
from django.template.utils import get_app_template_dirs
|
||||
|
||||
get_app_template_dirs.cache_clear()
|
||||
# Rebuild translations cache.
|
||||
from django.utils.translation import trans_real
|
||||
|
||||
trans_real._translations = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def update_connections_time_zone(**kwargs):
|
||||
if kwargs["setting"] == "TIME_ZONE":
|
||||
if kwargs['setting'] == 'TIME_ZONE':
|
||||
# Reset process time zone
|
||||
if hasattr(time, "tzset"):
|
||||
if kwargs["value"]:
|
||||
os.environ["TZ"] = kwargs["value"]
|
||||
if hasattr(time, 'tzset'):
|
||||
if kwargs['value']:
|
||||
os.environ['TZ'] = kwargs['value']
|
||||
else:
|
||||
os.environ.pop("TZ", None)
|
||||
os.environ.pop('TZ', None)
|
||||
time.tzset()
|
||||
|
||||
# Reset local time zone cache
|
||||
timezone.get_default_timezone.cache_clear()
|
||||
|
||||
# Reset the database connections' time zone
|
||||
if kwargs["setting"] in {"TIME_ZONE", "USE_TZ"}:
|
||||
if kwargs['setting'] in {'TIME_ZONE', 'USE_TZ'}:
|
||||
for conn in connections.all():
|
||||
try:
|
||||
del conn.timezone
|
||||
@@ -84,19 +79,18 @@ def update_connections_time_zone(**kwargs):
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_routers_cache(**kwargs):
|
||||
if kwargs["setting"] == "DATABASE_ROUTERS":
|
||||
if kwargs['setting'] == 'DATABASE_ROUTERS':
|
||||
router.routers = ConnectionRouter().routers
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def reset_template_engines(**kwargs):
|
||||
if kwargs["setting"] in {
|
||||
"TEMPLATES",
|
||||
"DEBUG",
|
||||
"INSTALLED_APPS",
|
||||
if kwargs['setting'] in {
|
||||
'TEMPLATES',
|
||||
'DEBUG',
|
||||
'INSTALLED_APPS',
|
||||
}:
|
||||
from django.template import engines
|
||||
|
||||
try:
|
||||
del engines.templates
|
||||
except AttributeError:
|
||||
@@ -104,134 +98,112 @@ def reset_template_engines(**kwargs):
|
||||
engines._templates = None
|
||||
engines._engines = {}
|
||||
from django.template.engine import Engine
|
||||
|
||||
Engine.get_default.cache_clear()
|
||||
from django.forms.renderers import get_default_renderer
|
||||
|
||||
get_default_renderer.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def clear_serializers_cache(**kwargs):
|
||||
if kwargs["setting"] == "SERIALIZATION_MODULES":
|
||||
if kwargs['setting'] == 'SERIALIZATION_MODULES':
|
||||
from django.core import serializers
|
||||
|
||||
serializers._serializers = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def language_changed(**kwargs):
|
||||
if kwargs["setting"] in {"LANGUAGES", "LANGUAGE_CODE", "LOCALE_PATHS"}:
|
||||
if kwargs['setting'] in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}:
|
||||
from django.utils.translation import trans_real
|
||||
|
||||
trans_real._default = None
|
||||
trans_real._active = Local()
|
||||
if kwargs["setting"] in {"LANGUAGES", "LOCALE_PATHS"}:
|
||||
if kwargs['setting'] in {'LANGUAGES', 'LOCALE_PATHS'}:
|
||||
from django.utils.translation import trans_real
|
||||
|
||||
trans_real._translations = {}
|
||||
trans_real.check_for_language.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def localize_settings_changed(**kwargs):
|
||||
if (
|
||||
kwargs["setting"] in FORMAT_SETTINGS
|
||||
or kwargs["setting"] == "USE_THOUSAND_SEPARATOR"
|
||||
):
|
||||
if kwargs['setting'] in FORMAT_SETTINGS or kwargs['setting'] == 'USE_THOUSAND_SEPARATOR':
|
||||
reset_format_cache()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def file_storage_changed(**kwargs):
|
||||
if kwargs["setting"] == "DEFAULT_FILE_STORAGE":
|
||||
if kwargs['setting'] == 'DEFAULT_FILE_STORAGE':
|
||||
from django.core.files.storage import default_storage
|
||||
|
||||
default_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def complex_setting_changed(**kwargs):
|
||||
if kwargs["enter"] and kwargs["setting"] in COMPLEX_OVERRIDE_SETTINGS:
|
||||
if kwargs['enter'] and kwargs['setting'] in COMPLEX_OVERRIDE_SETTINGS:
|
||||
# Considering the current implementation of the signals framework,
|
||||
# this stacklevel shows the line containing the override_settings call.
|
||||
warnings.warn(
|
||||
"Overriding setting %s can lead to unexpected behavior."
|
||||
% kwargs["setting"],
|
||||
stacklevel=6,
|
||||
)
|
||||
warnings.warn("Overriding setting %s can lead to unexpected behavior."
|
||||
% kwargs['setting'], stacklevel=6)
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def root_urlconf_changed(**kwargs):
|
||||
if kwargs["setting"] == "ROOT_URLCONF":
|
||||
if kwargs['setting'] == 'ROOT_URLCONF':
|
||||
from django.urls import clear_url_caches, set_urlconf
|
||||
|
||||
clear_url_caches()
|
||||
set_urlconf(None)
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def static_storage_changed(**kwargs):
|
||||
if kwargs["setting"] in {
|
||||
"STATICFILES_STORAGE",
|
||||
"STATIC_ROOT",
|
||||
"STATIC_URL",
|
||||
if kwargs['setting'] in {
|
||||
'STATICFILES_STORAGE',
|
||||
'STATIC_ROOT',
|
||||
'STATIC_URL',
|
||||
}:
|
||||
from django.contrib.staticfiles.storage import staticfiles_storage
|
||||
|
||||
staticfiles_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def static_finders_changed(**kwargs):
|
||||
if kwargs["setting"] in {
|
||||
"STATICFILES_DIRS",
|
||||
"STATIC_ROOT",
|
||||
if kwargs['setting'] in {
|
||||
'STATICFILES_DIRS',
|
||||
'STATIC_ROOT',
|
||||
}:
|
||||
from django.contrib.staticfiles.finders import get_finder
|
||||
|
||||
get_finder.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def auth_password_validators_changed(**kwargs):
|
||||
if kwargs["setting"] == "AUTH_PASSWORD_VALIDATORS":
|
||||
if kwargs['setting'] == 'AUTH_PASSWORD_VALIDATORS':
|
||||
from django.contrib.auth.password_validation import (
|
||||
get_default_password_validators,
|
||||
)
|
||||
|
||||
get_default_password_validators.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
|
||||
def user_model_swapped(**kwargs):
|
||||
if kwargs["setting"] == "AUTH_USER_MODEL":
|
||||
if kwargs['setting'] == 'AUTH_USER_MODEL':
|
||||
apps.clear_cache()
|
||||
try:
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
UserModel = get_user_model()
|
||||
except ImproperlyConfigured:
|
||||
# Some tests set an invalid AUTH_USER_MODEL.
|
||||
pass
|
||||
else:
|
||||
from django.contrib.auth import backends
|
||||
|
||||
backends.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth import forms
|
||||
|
||||
forms.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth.handlers import modwsgi
|
||||
|
||||
modwsgi.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth.management.commands import changepassword
|
||||
|
||||
changepassword.UserModel = UserModel
|
||||
|
||||
from django.contrib.auth import views
|
||||
|
||||
views.UserModel = UserModel
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,6 @@ from django.db.models.options import Options
|
||||
from django.template import Template
|
||||
from django.test.signals import setting_changed, template_rendered
|
||||
from django.urls import get_script_prefix, set_script_prefix
|
||||
from django.utils.deprecation import RemovedInDjango50Warning
|
||||
from django.utils.translation import deactivate
|
||||
|
||||
try:
|
||||
@@ -35,24 +34,15 @@ except ImportError:
|
||||
|
||||
|
||||
__all__ = (
|
||||
"Approximate",
|
||||
"ContextList",
|
||||
"isolate_lru_cache",
|
||||
"get_runner",
|
||||
"CaptureQueriesContext",
|
||||
"ignore_warnings",
|
||||
"isolate_apps",
|
||||
"modify_settings",
|
||||
"override_settings",
|
||||
"override_system_checks",
|
||||
"tag",
|
||||
"requires_tz_support",
|
||||
"setup_databases",
|
||||
"setup_test_environment",
|
||||
"teardown_test_environment",
|
||||
'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner',
|
||||
'CaptureQueriesContext',
|
||||
'ignore_warnings', 'isolate_apps', 'modify_settings', 'override_settings',
|
||||
'override_system_checks', 'tag',
|
||||
'requires_tz_support',
|
||||
'setup_databases', 'setup_test_environment', 'teardown_test_environment',
|
||||
)
|
||||
|
||||
TZ_SUPPORT = hasattr(time, "tzset")
|
||||
TZ_SUPPORT = hasattr(time, 'tzset')
|
||||
|
||||
|
||||
class Approximate:
|
||||
@@ -72,7 +62,6 @@ class ContextList(list):
|
||||
A wrapper that provides direct key access to context items contained
|
||||
in a list of context objects.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, str):
|
||||
for subcontext in self:
|
||||
@@ -120,7 +109,7 @@ def setup_test_environment(debug=None):
|
||||
Perform global pre-test setup, such as installing the instrumented template
|
||||
renderer and setting the email backend to the locmem email backend.
|
||||
"""
|
||||
if hasattr(_TestState, "saved_data"):
|
||||
if hasattr(_TestState, 'saved_data'):
|
||||
# Executing this function twice would overwrite the saved values.
|
||||
raise RuntimeError(
|
||||
"setup_test_environment() was already called and can't be called "
|
||||
@@ -135,13 +124,13 @@ def setup_test_environment(debug=None):
|
||||
|
||||
saved_data.allowed_hosts = settings.ALLOWED_HOSTS
|
||||
# Add the default host of the test client.
|
||||
settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, "testserver"]
|
||||
settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver']
|
||||
|
||||
saved_data.debug = settings.DEBUG
|
||||
settings.DEBUG = debug
|
||||
|
||||
saved_data.email_backend = settings.EMAIL_BACKEND
|
||||
settings.EMAIL_BACKEND = "django.core.mail.backends.locmem.EmailBackend"
|
||||
settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'
|
||||
|
||||
saved_data.template_render = Template._render
|
||||
Template._render = instrumented_test_render
|
||||
@@ -167,18 +156,8 @@ def teardown_test_environment():
|
||||
del mail.outbox
|
||||
|
||||
|
||||
def setup_databases(
|
||||
verbosity,
|
||||
interactive,
|
||||
*,
|
||||
time_keeper=None,
|
||||
keepdb=False,
|
||||
debug_sql=False,
|
||||
parallel=0,
|
||||
aliases=None,
|
||||
serialized_aliases=None,
|
||||
**kwargs,
|
||||
):
|
||||
def setup_databases(verbosity, interactive, *, time_keeper=None, keepdb=False, debug_sql=False, parallel=0,
|
||||
aliases=None, **kwargs):
|
||||
"""Create the test databases."""
|
||||
if time_keeper is None:
|
||||
time_keeper = NullTimeKeeper()
|
||||
@@ -197,31 +176,11 @@ def setup_databases(
|
||||
if first_alias is None:
|
||||
first_alias = alias
|
||||
with time_keeper.timed(" Creating '%s'" % alias):
|
||||
# RemovedInDjango50Warning: when the deprecation ends,
|
||||
# replace with:
|
||||
# serialize_alias = (
|
||||
# serialized_aliases is None
|
||||
# or alias in serialized_aliases
|
||||
# )
|
||||
try:
|
||||
serialize_alias = connection.settings_dict["TEST"]["SERIALIZE"]
|
||||
except KeyError:
|
||||
serialize_alias = (
|
||||
serialized_aliases is None or alias in serialized_aliases
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"The SERIALIZE test database setting is "
|
||||
"deprecated as it can be inferred from the "
|
||||
"TestCase/TransactionTestCase.databases that "
|
||||
"enable the serialized_rollback feature.",
|
||||
category=RemovedInDjango50Warning,
|
||||
)
|
||||
connection.creation.create_test_db(
|
||||
verbosity=verbosity,
|
||||
autoclobber=not interactive,
|
||||
keepdb=keepdb,
|
||||
serialize=serialize_alias,
|
||||
serialize=connection.settings_dict['TEST'].get('SERIALIZE', True),
|
||||
)
|
||||
if parallel > 1:
|
||||
for index in range(parallel):
|
||||
@@ -233,15 +192,12 @@ def setup_databases(
|
||||
)
|
||||
# Configure all other connections as mirrors of the first one
|
||||
else:
|
||||
connections[alias].creation.set_as_test_mirror(
|
||||
connections[first_alias].settings_dict
|
||||
)
|
||||
connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict)
|
||||
|
||||
# Configure the test mirrors.
|
||||
for alias, mirror_alias in mirrored_aliases.items():
|
||||
connections[alias].creation.set_as_test_mirror(
|
||||
connections[mirror_alias].settings_dict
|
||||
)
|
||||
connections[mirror_alias].settings_dict)
|
||||
|
||||
if debug_sql:
|
||||
for alias in connections:
|
||||
@@ -250,27 +206,6 @@ def setup_databases(
|
||||
return old_names
|
||||
|
||||
|
||||
def iter_test_cases(tests):
|
||||
"""
|
||||
Return an iterator over a test suite's unittest.TestCase objects.
|
||||
|
||||
The tests argument can also be an iterable of TestCase objects.
|
||||
"""
|
||||
for test in tests:
|
||||
if isinstance(test, str):
|
||||
# Prevent an unfriendly RecursionError that can happen with
|
||||
# strings.
|
||||
raise TypeError(
|
||||
f"Test {test!r} must be a test case or test suite not string "
|
||||
f"(was found in {tests!r})."
|
||||
)
|
||||
if isinstance(test, TestCase):
|
||||
yield test
|
||||
else:
|
||||
# Otherwise, assume it is a test suite.
|
||||
yield from iter_test_cases(test)
|
||||
|
||||
|
||||
def dependency_ordered(test_databases, dependencies):
|
||||
"""
|
||||
Reorder test_databases into an order that honors the dependencies
|
||||
@@ -334,18 +269,18 @@ def get_unique_databases_and_mirrors(aliases=None):
|
||||
|
||||
for alias in connections:
|
||||
connection = connections[alias]
|
||||
test_settings = connection.settings_dict["TEST"]
|
||||
test_settings = connection.settings_dict['TEST']
|
||||
|
||||
if test_settings["MIRROR"]:
|
||||
if test_settings['MIRROR']:
|
||||
# If the database is marked as a test mirror, save the alias.
|
||||
mirrored_aliases[alias] = test_settings["MIRROR"]
|
||||
mirrored_aliases[alias] = test_settings['MIRROR']
|
||||
elif alias in aliases:
|
||||
# Store a tuple with DB parameters that uniquely identify it.
|
||||
# If we have two aliases with the same values for that tuple,
|
||||
# we only need to create the test database once.
|
||||
item = test_databases.setdefault(
|
||||
connection.creation.test_db_signature(),
|
||||
(connection.settings_dict["NAME"], []),
|
||||
(connection.settings_dict['NAME'], []),
|
||||
)
|
||||
# The default database must be the first because data migrations
|
||||
# use the default alias by default.
|
||||
@@ -354,16 +289,11 @@ def get_unique_databases_and_mirrors(aliases=None):
|
||||
else:
|
||||
item[1].append(alias)
|
||||
|
||||
if "DEPENDENCIES" in test_settings:
|
||||
dependencies[alias] = test_settings["DEPENDENCIES"]
|
||||
if 'DEPENDENCIES' in test_settings:
|
||||
dependencies[alias] = test_settings['DEPENDENCIES']
|
||||
else:
|
||||
if (
|
||||
alias != DEFAULT_DB_ALIAS
|
||||
and connection.creation.test_db_signature() != default_sig
|
||||
):
|
||||
dependencies[alias] = test_settings.get(
|
||||
"DEPENDENCIES", [DEFAULT_DB_ALIAS]
|
||||
)
|
||||
if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig:
|
||||
dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS])
|
||||
|
||||
test_databases = dict(dependency_ordered(test_databases.items(), dependencies))
|
||||
return test_databases, mirrored_aliases
|
||||
@@ -385,12 +315,12 @@ def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
|
||||
|
||||
def get_runner(settings, test_runner_class=None):
|
||||
test_runner_class = test_runner_class or settings.TEST_RUNNER
|
||||
test_path = test_runner_class.split(".")
|
||||
test_path = test_runner_class.split('.')
|
||||
# Allow for relative paths
|
||||
if len(test_path) > 1:
|
||||
test_module_name = ".".join(test_path[:-1])
|
||||
test_module_name = '.'.join(test_path[:-1])
|
||||
else:
|
||||
test_module_name = "."
|
||||
test_module_name = '.'
|
||||
test_module = __import__(test_module_name, {}, {}, test_path[-1])
|
||||
return getattr(test_module, test_path[-1])
|
||||
|
||||
@@ -407,7 +337,6 @@ class TestContextDecorator:
|
||||
`kwarg_name`: keyword argument passing the return value of enable() if
|
||||
used as a function decorator.
|
||||
"""
|
||||
|
||||
def __init__(self, attr_name=None, kwarg_name=None):
|
||||
self.attr_name = attr_name
|
||||
self.kwarg_name = kwarg_name
|
||||
@@ -437,7 +366,7 @@ class TestContextDecorator:
|
||||
|
||||
cls.setUp = setUp
|
||||
return cls
|
||||
raise TypeError("Can only decorate subclasses of unittest.TestCase")
|
||||
raise TypeError('Can only decorate subclasses of unittest.TestCase')
|
||||
|
||||
def decorate_callable(self, func):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@@ -449,16 +378,13 @@ class TestContextDecorator:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
with self as context:
|
||||
if self.kwarg_name:
|
||||
kwargs[self.kwarg_name] = context
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
def __call__(self, decorated):
|
||||
@@ -466,7 +392,7 @@ class TestContextDecorator:
|
||||
return self.decorate_class(decorated)
|
||||
elif callable(decorated):
|
||||
return self.decorate_callable(decorated)
|
||||
raise TypeError("Cannot decorate object of type %s" % type(decorated))
|
||||
raise TypeError('Cannot decorate object of type %s' % type(decorated))
|
||||
|
||||
|
||||
class override_settings(TestContextDecorator):
|
||||
@@ -476,7 +402,6 @@ class override_settings(TestContextDecorator):
|
||||
with the ``with`` statement. In either event, entering/exiting are called
|
||||
before and after, respectively, the function/block is executed.
|
||||
"""
|
||||
|
||||
enable_exception = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -486,9 +411,9 @@ class override_settings(TestContextDecorator):
|
||||
def enable(self):
|
||||
# Keep this code at the beginning to leave the settings unchanged
|
||||
# in case it raises an exception because INSTALLED_APPS is invalid.
|
||||
if "INSTALLED_APPS" in self.options:
|
||||
if 'INSTALLED_APPS' in self.options:
|
||||
try:
|
||||
apps.set_installed_apps(self.options["INSTALLED_APPS"])
|
||||
apps.set_installed_apps(self.options['INSTALLED_APPS'])
|
||||
except Exception:
|
||||
apps.unset_installed_apps()
|
||||
raise
|
||||
@@ -501,16 +426,14 @@ class override_settings(TestContextDecorator):
|
||||
try:
|
||||
setting_changed.send(
|
||||
sender=settings._wrapped.__class__,
|
||||
setting=key,
|
||||
value=new_value,
|
||||
enter=True,
|
||||
setting=key, value=new_value, enter=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.enable_exception = exc
|
||||
self.disable()
|
||||
|
||||
def disable(self):
|
||||
if "INSTALLED_APPS" in self.options:
|
||||
if 'INSTALLED_APPS' in self.options:
|
||||
apps.unset_installed_apps()
|
||||
settings._wrapped = self.wrapped
|
||||
del self.wrapped
|
||||
@@ -519,9 +442,7 @@ class override_settings(TestContextDecorator):
|
||||
new_value = getattr(settings, key, None)
|
||||
responses_for_setting = setting_changed.send_robust(
|
||||
sender=settings._wrapped.__class__,
|
||||
setting=key,
|
||||
value=new_value,
|
||||
enter=False,
|
||||
setting=key, value=new_value, enter=False,
|
||||
)
|
||||
responses.extend(responses_for_setting)
|
||||
if self.enable_exception is not None:
|
||||
@@ -544,12 +465,10 @@ class override_settings(TestContextDecorator):
|
||||
|
||||
def decorate_class(self, cls):
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
if not issubclass(cls, SimpleTestCase):
|
||||
raise ValueError(
|
||||
"Only subclasses of Django SimpleTestCase can be decorated "
|
||||
"with override_settings"
|
||||
)
|
||||
"with override_settings")
|
||||
self.save_options(cls)
|
||||
return cls
|
||||
|
||||
@@ -559,7 +478,6 @@ class modify_settings(override_settings):
|
||||
Like override_settings, but makes it possible to append, prepend, or remove
|
||||
items instead of redefining the entire list.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if args:
|
||||
# Hack used when instantiating from SimpleTestCase.setUpClass.
|
||||
@@ -575,9 +493,8 @@ class modify_settings(override_settings):
|
||||
test_func._modified_settings = self.operations
|
||||
else:
|
||||
# Duplicate list to prevent subclasses from altering their parent.
|
||||
test_func._modified_settings = (
|
||||
list(test_func._modified_settings) + self.operations
|
||||
)
|
||||
test_func._modified_settings = list(
|
||||
test_func._modified_settings) + self.operations
|
||||
|
||||
def enable(self):
|
||||
self.options = {}
|
||||
@@ -592,11 +509,11 @@ class modify_settings(override_settings):
|
||||
# items my be a single value or an iterable.
|
||||
if isinstance(items, str):
|
||||
items = [items]
|
||||
if action == "append":
|
||||
if action == 'append':
|
||||
value = value + [item for item in items if item not in value]
|
||||
elif action == "prepend":
|
||||
elif action == 'prepend':
|
||||
value = [item for item in items if item not in value] + value
|
||||
elif action == "remove":
|
||||
elif action == 'remove':
|
||||
value = [item for item in value if item not in items]
|
||||
else:
|
||||
raise ValueError("Unsupported action: %s" % action)
|
||||
@@ -610,10 +527,8 @@ class override_system_checks(TestContextDecorator):
|
||||
Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
|
||||
you also need to exclude its system checks.
|
||||
"""
|
||||
|
||||
def __init__(self, new_checks, deployment_checks=None):
|
||||
from django.core.checks.registry import registry
|
||||
|
||||
self.registry = registry
|
||||
self.new_checks = new_checks
|
||||
self.deployment_checks = deployment_checks
|
||||
@@ -623,12 +538,12 @@ class override_system_checks(TestContextDecorator):
|
||||
self.old_checks = self.registry.registered_checks
|
||||
self.registry.registered_checks = set()
|
||||
for check in self.new_checks:
|
||||
self.registry.register(check, *getattr(check, "tags", ()))
|
||||
self.registry.register(check, *getattr(check, 'tags', ()))
|
||||
self.old_deployment_checks = self.registry.deployment_checks
|
||||
if self.deployment_checks is not None:
|
||||
self.registry.deployment_checks = set()
|
||||
for check in self.deployment_checks:
|
||||
self.registry.register(check, *getattr(check, "tags", ()), deploy=True)
|
||||
self.registry.register(check, *getattr(check, 'tags', ()), deploy=True)
|
||||
|
||||
def disable(self):
|
||||
self.registry.registered_checks = self.old_checks
|
||||
@@ -644,18 +559,18 @@ def compare_xml(want, got):
|
||||
|
||||
Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
|
||||
"""
|
||||
_norm_whitespace_re = re.compile(r"[ \t\n][ \t\n]+")
|
||||
_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
|
||||
|
||||
def norm_whitespace(v):
|
||||
return _norm_whitespace_re.sub(" ", v)
|
||||
return _norm_whitespace_re.sub(' ', v)
|
||||
|
||||
def child_text(element):
|
||||
return "".join(
|
||||
c.data for c in element.childNodes if c.nodeType == Node.TEXT_NODE
|
||||
)
|
||||
return ''.join(c.data for c in element.childNodes
|
||||
if c.nodeType == Node.TEXT_NODE)
|
||||
|
||||
def children(element):
|
||||
return [c for c in element.childNodes if c.nodeType == Node.ELEMENT_NODE]
|
||||
return [c for c in element.childNodes
|
||||
if c.nodeType == Node.ELEMENT_NODE]
|
||||
|
||||
def norm_child_text(element):
|
||||
return norm_whitespace(child_text(element))
|
||||
@@ -674,9 +589,7 @@ def compare_xml(want, got):
|
||||
got_children = children(got_element)
|
||||
if len(want_children) != len(got_children):
|
||||
return False
|
||||
return all(
|
||||
check_element(want, got) for want, got in zip(want_children, got_children)
|
||||
)
|
||||
return all(check_element(want, got) for want, got in zip(want_children, got_children))
|
||||
|
||||
def first_node(document):
|
||||
for node in document.childNodes:
|
||||
@@ -687,13 +600,13 @@ def compare_xml(want, got):
|
||||
):
|
||||
return node
|
||||
|
||||
want = want.strip().replace("\\n", "\n")
|
||||
got = got.strip().replace("\\n", "\n")
|
||||
want = want.strip().replace('\\n', '\n')
|
||||
got = got.strip().replace('\\n', '\n')
|
||||
|
||||
# If the string is not a complete xml document, we may need to add a
|
||||
# root element. This allow us to compare fragments, like "<foo/><bar/>"
|
||||
if not want.startswith("<?xml"):
|
||||
wrapper = "<root>%s</root>"
|
||||
if not want.startswith('<?xml'):
|
||||
wrapper = '<root>%s</root>'
|
||||
want = wrapper % want
|
||||
got = wrapper % got
|
||||
|
||||
@@ -708,7 +621,6 @@ class CaptureQueriesContext:
|
||||
"""
|
||||
Context manager that captures queries executed by the specified connection.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
@@ -723,7 +635,7 @@ class CaptureQueriesContext:
|
||||
|
||||
@property
|
||||
def captured_queries(self):
|
||||
return self.connection.queries[self.initial_queries : self.final_queries]
|
||||
return self.connection.queries[self.initial_queries:self.final_queries]
|
||||
|
||||
def __enter__(self):
|
||||
self.force_debug_cursor = self.connection.force_debug_cursor
|
||||
@@ -747,7 +659,7 @@ class CaptureQueriesContext:
|
||||
class ignore_warnings(TestContextDecorator):
|
||||
def __init__(self, **kwargs):
|
||||
self.ignore_kwargs = kwargs
|
||||
if "message" in self.ignore_kwargs or "module" in self.ignore_kwargs:
|
||||
if 'message' in self.ignore_kwargs or 'module' in self.ignore_kwargs:
|
||||
self.filter_func = warnings.filterwarnings
|
||||
else:
|
||||
self.filter_func = warnings.simplefilter
|
||||
@@ -756,7 +668,7 @@ class ignore_warnings(TestContextDecorator):
|
||||
def enable(self):
|
||||
self.catch_warnings = warnings.catch_warnings()
|
||||
self.catch_warnings.__enter__()
|
||||
self.filter_func("ignore", **self.ignore_kwargs)
|
||||
self.filter_func('ignore', **self.ignore_kwargs)
|
||||
|
||||
def disable(self):
|
||||
self.catch_warnings.__exit__(*sys.exc_info())
|
||||
@@ -770,7 +682,7 @@ class ignore_warnings(TestContextDecorator):
|
||||
requires_tz_support = skipUnless(
|
||||
TZ_SUPPORT,
|
||||
"This test relies on the ability to run a program in an arbitrary "
|
||||
"time zone, but your operating system isn't able to do that.",
|
||||
"time zone, but your operating system isn't able to do that."
|
||||
)
|
||||
|
||||
|
||||
@@ -813,9 +725,9 @@ def captured_output(stream_name):
|
||||
def captured_stdout():
|
||||
"""Capture the output of sys.stdout:
|
||||
|
||||
with captured_stdout() as stdout:
|
||||
print("hello")
|
||||
self.assertEqual(stdout.getvalue(), "hello\n")
|
||||
with captured_stdout() as stdout:
|
||||
print("hello")
|
||||
self.assertEqual(stdout.getvalue(), "hello\n")
|
||||
"""
|
||||
return captured_output("stdout")
|
||||
|
||||
@@ -823,9 +735,9 @@ def captured_stdout():
|
||||
def captured_stderr():
|
||||
"""Capture the output of sys.stderr:
|
||||
|
||||
with captured_stderr() as stderr:
|
||||
print("hello", file=sys.stderr)
|
||||
self.assertEqual(stderr.getvalue(), "hello\n")
|
||||
with captured_stderr() as stderr:
|
||||
print("hello", file=sys.stderr)
|
||||
self.assertEqual(stderr.getvalue(), "hello\n")
|
||||
"""
|
||||
return captured_output("stderr")
|
||||
|
||||
@@ -833,12 +745,12 @@ def captured_stderr():
|
||||
def captured_stdin():
|
||||
"""Capture the input to sys.stdin:
|
||||
|
||||
with captured_stdin() as stdin:
|
||||
stdin.write('hello\n')
|
||||
stdin.seek(0)
|
||||
# call test code that consumes from sys.stdin
|
||||
captured = input()
|
||||
self.assertEqual(captured, "hello")
|
||||
with captured_stdin() as stdin:
|
||||
stdin.write('hello\n')
|
||||
stdin.seek(0)
|
||||
# call test code that consumes from sys.stdin
|
||||
captured = input()
|
||||
self.assertEqual(captured, "hello")
|
||||
"""
|
||||
return captured_output("stdin")
|
||||
|
||||
@@ -866,24 +778,18 @@ def require_jinja2(test_func):
|
||||
Django template engine for a test or skip it if Jinja2 isn't available.
|
||||
"""
|
||||
test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func)
|
||||
return override_settings(
|
||||
TEMPLATES=[
|
||||
{
|
||||
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
||||
"APP_DIRS": True,
|
||||
},
|
||||
{
|
||||
"BACKEND": "django.template.backends.jinja2.Jinja2",
|
||||
"APP_DIRS": True,
|
||||
"OPTIONS": {"keep_trailing_newline": True},
|
||||
},
|
||||
]
|
||||
)(test_func)
|
||||
return override_settings(TEMPLATES=[{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'APP_DIRS': True,
|
||||
}, {
|
||||
'BACKEND': 'django.template.backends.jinja2.Jinja2',
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {'keep_trailing_newline': True},
|
||||
}])(test_func)
|
||||
|
||||
|
||||
class override_script_prefix(TestContextDecorator):
|
||||
"""Decorator or context manager to temporary override the script prefix."""
|
||||
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
super().__init__()
|
||||
@@ -901,9 +807,8 @@ class LoggingCaptureMixin:
|
||||
Capture the output from the 'django' logger and store it on the class's
|
||||
logger_output attribute.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.logger = logging.getLogger("django")
|
||||
self.logger = logging.getLogger('django')
|
||||
self.old_stream = self.logger.handlers[0].stream
|
||||
self.logger_output = StringIO()
|
||||
self.logger.handlers[0].stream = self.logger_output
|
||||
@@ -928,7 +833,6 @@ class isolate_apps(TestContextDecorator):
|
||||
`kwarg_name`: keyword argument passing the isolated registry if used as a
|
||||
function decorator.
|
||||
"""
|
||||
|
||||
def __init__(self, *installed_apps, **kwargs):
|
||||
self.installed_apps = installed_apps
|
||||
super().__init__(**kwargs)
|
||||
@@ -936,11 +840,11 @@ class isolate_apps(TestContextDecorator):
|
||||
def enable(self):
|
||||
self.old_apps = Options.default_apps
|
||||
apps = Apps(self.installed_apps)
|
||||
setattr(Options, "default_apps", apps)
|
||||
setattr(Options, 'default_apps', apps)
|
||||
return apps
|
||||
|
||||
def disable(self):
|
||||
setattr(Options, "default_apps", self.old_apps)
|
||||
setattr(Options, 'default_apps', self.old_apps)
|
||||
|
||||
|
||||
class TimeKeeper:
|
||||
@@ -960,7 +864,7 @@ class TimeKeeper:
|
||||
def print_results(self):
|
||||
for name, end_times in self.records.items():
|
||||
for record_time in end_times:
|
||||
record = "%s took %.3fs" % (name, record_time)
|
||||
record = '%s took %.3fs' % (name, record_time)
|
||||
sys.stderr.write(record + os.linesep)
|
||||
|
||||
|
||||
@@ -975,14 +879,12 @@ class NullTimeKeeper:
|
||||
|
||||
def tag(*tags):
|
||||
"""Decorator to add tags to a test class or method."""
|
||||
|
||||
def decorator(obj):
|
||||
if hasattr(obj, "tags"):
|
||||
if hasattr(obj, 'tags'):
|
||||
obj.tags = obj.tags.union(tags)
|
||||
else:
|
||||
setattr(obj, "tags", set(tags))
|
||||
setattr(obj, 'tags', set(tags))
|
||||
return obj
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user