Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17365.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md) by adding _matrix/client/v1/media/download endpoint.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,6 @@ ignore_missing_imports = True
# https://github.com/twisted/treq/pull/366
[mypy-treq.*]
ignore_missing_imports = True

[mypy-multipart.*]
ignore_missing_imports = True
18 changes: 16 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ pydantic = ">=1.7.4, <3"
# needed.
setuptools_rust = ">=1.3"

# This is used for parsing multipart responses
python-multipart = ">=0.0.9"

# Optional Dependencies
# ---------------------
Expand Down
46 changes: 46 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1871,6 +1871,52 @@ def filter_user_id(user_id: str) -> bool:

return filtered_statuses, filtered_failures

async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Union[
Tuple[int, Dict[bytes, List[bytes]], bytes],
Tuple[int, Dict[bytes, List[bytes]]],
]:
try:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise

logger.debug(
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
destination,
media_id,
)

return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)

async def download_media(
self,
destination: str,
Expand Down
25 changes: 23 additions & 2 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,6 @@ async def download_media_r0(
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"

return await self.client.get_file(
destination,
path,
Expand Down Expand Up @@ -852,7 +851,6 @@ async def download_media_v3(
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"

return await self.client.get_file(
destination,
path,
Expand All @@ -873,6 +871,29 @@ async def download_media_v3(
ip_address=ip_address,
)

async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
path = f"/_matrix/federation/v1/media/download/{media_id}"
Comment thread
anoadragon453 marked this conversation as resolved.
return await self.client.federation_get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)


def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
Expand Down
9 changes: 3 additions & 6 deletions synapse/federation/transport/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationMediaDownloadServlet,
FederationUnstableClientKeysClaimServlet,
FederationUnstableMediaDownloadServlet,
)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
Expand Down Expand Up @@ -316,11 +316,8 @@ def register_servlets(
):
continue

if servletclass == FederationUnstableMediaDownloadServlet:
if (
not hs.config.server.enable_media_repo
or not hs.config.experimental.msc3916_authenticated_media_enabled
):
if servletclass == FederationMediaDownloadServlet:
if not hs.config.server.enable_media_repo:
continue

servletclass(
Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/transport/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ async def new_func(
return None
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet"
== "FederationMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
Expand All @@ -374,7 +374,7 @@ async def new_func(
else:
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet"
== "FederationMediaDownloadServlet"
):
response = await func(
origin, content, request, *args, **kwargs
Expand Down
5 changes: 2 additions & 3 deletions synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,15 +790,14 @@ async def on_POST(
return 200, {"account_statuses": statuses, "failures": failures}


class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
class FederationMediaDownloadServlet(BaseFederationServerServlet):
"""
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
a multipart/mixed response consisting of a JSON object and the requested media
item. This endpoint only returns local media.
"""

PATH = "/media/download/(?P<media_id>[^/]*)"
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
RATELIMIT = True

def __init__(
Expand Down Expand Up @@ -858,5 +857,5 @@ async def on_GET(
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
FederationAccountStatusServlet,
FederationUnstableMediaDownloadServlet,
FederationMediaDownloadServlet,
)
152 changes: 152 additions & 0 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
Union,
)

import attr
import multipart
import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
Expand Down Expand Up @@ -1006,6 +1008,130 @@ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()


@attr.s(auto_attribs=True, slots=True)
class MultipartResponse:
"""
A small class to hold parsed values of a multipart response.
"""

json: bytes = b"{}"
length: Optional[int] = None
content_type: Optional[bytes] = None
disposition: Optional[bytes] = None
url: Optional[bytes] = None


class _MultipartParserProtocol(protocol.Protocol):
"""
Protocol to read and parse a MSC3916 multipart/mixed response
"""

transport: Optional[ITCPTransport] = None

def __init__(
self,
stream: ByteWriteable,
deferred: defer.Deferred,
boundary: str,
max_length: Optional[int],
) -> None:
self.stream = stream
self.deferred = deferred
self.boundary = boundary
self.max_length = max_length
self.parser = None
self.multipart_response = MultipartResponse()
self.has_redirect = False
self.in_json = False
self.json_done = False
self.file_length = 0
self.total_length = 0
self.in_disposition = False
self.in_content_type = False

def dataReceived(self, incoming_data: bytes) -> None:
if self.deferred.called:
return

# we don't have a parser yet, instantiate it
if not self.parser:

def on_header_field(data: bytes, start: int, end: int) -> None:
if data[start:end] == b"Location":
Comment thread
H-Shay marked this conversation as resolved.
self.has_redirect = True
if data[start:end] == b"Content-Disposition":
self.in_disposition = True
if data[start:end] == b"Content-Type":
self.in_content_type = True

def on_header_value(data: bytes, start: int, end: int) -> None:
# the first header should be content-type for application/json
if not self.in_json and not self.json_done:
assert data[start:end] == b"application/json"
self.in_json = True
elif self.has_redirect:
self.multipart_response.url = data[start:end]
elif self.in_content_type:
self.multipart_response.content_type = data[start:end]
self.in_content_type = False
elif self.in_disposition:
self.multipart_response.disposition = data[start:end]
self.in_disposition = False

def on_part_data(data: bytes, start: int, end: int) -> None:
# we've seen json header but haven't written the json data
if self.in_json and not self.json_done:
self.multipart_response.json = data[start:end]
self.json_done = True
# we have a redirect header rather than a file, and have already captured it
elif self.has_redirect:
return
# otherwise we are in the file part
else:
logger.info("Writing multipart file data to stream")
try:
self.stream.write(data[start:end])
except Exception as e:
logger.warning(
f"Exception encountered writing file data to stream: {e}"
)
self.deferred.errback()
self.file_length += end - start

callbacks = {
"on_header_field": on_header_field,
"on_header_value": on_header_value,
"on_part_data": on_part_data,
}
self.parser = multipart.MultipartParser(self.boundary, callbacks)

self.total_length += len(incoming_data)
if self.max_length is not None and self.total_length >= self.max_length:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()

try:
self.parser.write(incoming_data) # type: ignore[attr-defined]
except Exception as e:
logger.warning(f"Exception writing to multipart parser: {e}")
self.deferred.errback()
return

def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return

if reason.check(ResponseDone):
self.multipart_response.length = self.file_length
self.deferred.callback(self.multipart_response)
else:
self.deferred.errback(reason)


class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""

Expand Down Expand Up @@ -1091,6 +1217,32 @@ def read_body_with_max_size(
return d


def read_multipart_response(
response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
) -> "defer.Deferred[MultipartResponse]":
"""
Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
the stream passed in and returning a deferred resolving to a MultipartResponse

Args:
response: The HTTP response to read from.
stream: The file-object to write to.
boundary: the multipart/mixed boundary string
max_length: maximum allowable length of the response
"""
d: defer.Deferred[MultipartResponse] = defer.Deferred()

# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_length is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_length:
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d

response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
return d


def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
Expand Down
Loading