Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 5 additions & 30 deletions tests/federation/test_federation_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@

from tests import unittest
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config


class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
Expand All @@ -65,9 +64,6 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
)
self.media_repo = hs.get_media_repository()

@override_config(
{"experimental_features": {"msc3916_authenticated_media_enabled": True}}
)
def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
Expand All @@ -82,7 +78,7 @@ def test_file_download(self) -> None:
# test with a text file
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(200, channel.code)
Expand All @@ -106,7 +102,8 @@ def test_file_download(self) -> None:

# check that the text file and expected value exist
found_file = any(
"\r\nContent-Type: text/plain\r\n\r\nfile_to_stream" in field
"\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream"
in field
for field in stripped
)
self.assertTrue(found_file)
Expand All @@ -124,7 +121,7 @@ def test_file_download(self) -> None:
# test with an image file
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(200, channel.code)
Expand All @@ -149,25 +146,3 @@ def test_file_download(self) -> None:
# check that the png file exists and matches what was uploaded
found_file = any(SMALL_PNG in field for field in stripped_bytes)
self.assertTrue(found_file)

@override_config(
{"experimental_features": {"msc3916_authenticated_media_enabled": False}}
)
def test_disable_config(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.media_repo.create_content(
"text/plain",
"test_upload",
content,
46,
UserID.from_string("@user_id:whatever.org"),
)
)
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(404, channel.code)
self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")
143 changes: 140 additions & 3 deletions tests/http/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,155 @@
BlocklistingAgentWrapper,
BlocklistingReactorWrapper,
BodyExceededMaxSize,
MultipartResponse,
_DiscardBodyWithMaxSizeProtocol,
_MultipartParserProtocol,
read_body_with_max_size,
read_multipart_response,
)

from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase


class ReadMultipartResponseTests(TestCase):
data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"

redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"

def _build_multipart_response(
self, response_length: Union[int, str], max_length: int
) -> Tuple[
BytesIO,
"Deferred[MultipartResponse]",
_MultipartParserProtocol,
]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=response_length)
result = BytesIO()
boundary = "6067d4698f8d40a0a794ea7d7379d53a"
deferred = read_multipart_response(response, result, boundary, max_length)

# Fish the protocol out of the response.
protocol = response.deliverBody.call_args[0][0]
protocol.transport = Mock()

return result, deferred, protocol

def _assert_error(
self,
deferred: "Deferred[MultipartResponse]",
protocol: _MultipartParserProtocol,
) -> None:
"""Ensure that the expected error is received."""
assert isinstance(deferred.result, Failure)
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
assert protocol.transport is not None
# type-ignore: presumably abortConnection has been replaced with a Mock.
protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]

def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]

def errback(f: Failure) -> None:
called[0] = True

deferred.addErrback(errback)
self.assertTrue(called[0])

def test_parse_file(self) -> None:
"""
Check that a multipart response containing a file is properly parsed
into the json/file parts, and the json and file are properly captured
"""
result, deferred, protocol = self._build_multipart_response(249, 250)

# Start sending data.
protocol.dataReceived(self.data1)
protocol.dataReceived(self.data2)
# Close the connection.
protocol.connectionLost(Failure(ResponseDone()))

multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]

self.assertEqual(multipart_response.json, b"{}")
self.assertEqual(result.getvalue(), b"file_to_stream")
self.assertEqual(multipart_response.length, len(b"file_to_stream"))
self.assertEqual(multipart_response.content_type, b"text/plain")
self.assertEqual(
multipart_response.disposition, b"inline; filename=test_upload"
)

def test_parse_redirect(self) -> None:
"""
check that a multipart response containing a redirect is properly parsed and redirect url is
returned
"""
result, deferred, protocol = self._build_multipart_response(249, 250)

# Start sending data.
protocol.dataReceived(self.redirect_data)
# Close the connection.
protocol.connectionLost(Failure(ResponseDone()))

multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]

self.assertEqual(multipart_response.json, b"{}")
self.assertEqual(result.getvalue(), b"")
self.assertEqual(
multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
)

def test_too_large(self) -> None:
"""A response which is too large raises an exception."""
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)

# Start sending data.
protocol.dataReceived(self.data1)

self.assertEqual(result.getvalue(), b"file_")
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)

def test_additional_data(self) -> None:
"""A connection can receive data after being closed."""
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)

# Start sending data.
protocol.dataReceived(self.data1)
self._assert_error(deferred, protocol)

# More data might have come in.
protocol.dataReceived(self.data2)

self.assertEqual(result.getvalue(), b"file_")
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)

def test_content_length(self) -> None:
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
result, deferred, protocol = self._build_multipart_response(250, 1)

# Deferred shouldn't be called yet.
self.assertFalse(deferred.called)

# Start sending data.
protocol.dataReceived(self.data1)
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)

# The data is never consumed.
self.assertEqual(result.getvalue(), b"")


class ReadBodyWithMaxSizeTests(TestCase):
def _build_response(
self, length: Union[int, str] = UNKNOWN_LENGTH
) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
BytesIO,
"Deferred[int]",
_DiscardBodyWithMaxSizeProtocol,
]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=length)
result = BytesIO()
Expand Down
14 changes: 7 additions & 7 deletions tests/media/test_media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_ensure_media_is_in_local_cache(self) -> None:


@attr.s(auto_attribs=True, slots=True, frozen=True)
class _TestImage:
class TestImage:
"""An image for testing thumbnailing with the expected results

Attributes:
Expand Down Expand Up @@ -158,7 +158,7 @@ class _TestImage:
is_inline: bool = True


small_png = _TestImage(
small_png = TestImage(
SMALL_PNG,
b"image/png",
b".png",
Expand All @@ -175,7 +175,7 @@ class _TestImage:
),
)

small_png_with_transparency = _TestImage(
small_png_with_transparency = TestImage(
unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
Expand All @@ -188,23 +188,23 @@ class _TestImage:
# different versions of Pillow.
)

small_lossless_webp = _TestImage(
small_lossless_webp = TestImage(
unhexlify(
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
),
b"image/webp",
b".webp",
)

empty_file = _TestImage(
empty_file = TestImage(
b"",
b"image/gif",
b".gif",
expected_found=False,
unable_to_thumbnail=True,
)

SVG = _TestImage(
SVG = TestImage(
b"""<?xml version="1.0"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
Expand Down Expand Up @@ -236,7 +236,7 @@ class _TestImage:
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
class MediaRepoTests(unittest.HomeserverTestCase):
servlets = [media.register_servlets]
test_image: ClassVar[_TestImage]
test_image: ClassVar[TestImage]
hijack_auth = True
user_id = "@test:user"
url: ClassVar[str]
Expand Down
Loading