diff --git a/api/_apimeta/auth.go b/api/_apimeta/auth.go index 7bc7d4fb..37ec87ba 100644 --- a/api/_apimeta/auth.go +++ b/api/_apimeta/auth.go @@ -16,6 +16,10 @@ type UserInfo struct { IsShared bool } +type ServerInfo struct { + ServerName string +} + func GetRequestUserAdminStatus(r *http.Request, rctx rcontext.RequestContext, user UserInfo) (bool, bool) { isGlobalAdmin := util.IsGlobalAdmin(user.UserId) || user.IsShared isLocalAdmin, err := matrix.IsUserAdmin(rctx, r.Host, user.AccessToken, r.RemoteAddr) diff --git a/api/_auth_cache/auth_cache.go b/api/_auth_cache/auth_cache.go index cb66fb4a..595ea9b5 100644 --- a/api/_auth_cache/auth_cache.go +++ b/api/_auth_cache/auth_cache.go @@ -12,7 +12,7 @@ import ( "github.com/turt2live/matrix-media-repo/matrix" ) -var tokenCache = cache.New(0*time.Second, 30*time.Second) +var tokenCache = cache.New(cache.NoExpiration, 30*time.Second) var rwLock = &sync.RWMutex{} var regexCache = make(map[string]*regexp.Regexp) diff --git a/api/_routers/97-require-server-auth.go b/api/_routers/97-require-server-auth.go new file mode 100644 index 00000000..b7ca3bac --- /dev/null +++ b/api/_routers/97-require-server-auth.go @@ -0,0 +1,30 @@ +package _routers + +import ( + "net/http" + + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/matrix" +) + +type GeneratorWithServerFn = func(r *http.Request, ctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} + +func RequireServerAuth(generator GeneratorWithServerFn) GeneratorFn { + return func(r *http.Request, ctx rcontext.RequestContext) interface{} { + serverName, err := matrix.ValidateXMatrixAuth(r, true) + if err != nil { + ctx.Log.Debug("Error with X-Matrix auth: ", err) + return &_responses.ErrorResponse{ + Code: common.ErrCodeForbidden, + Message: "no auth provided (required)", + InternalCode: common.ErrCodeMissingToken, + } + } + return generator(r, ctx, _apimeta.ServerInfo{ + ServerName: serverName, + }) + } +} diff --git a/api/_routers/98-use-rcontext.go b/api/_routers/98-use-rcontext.go index 6a3b8e9f..ba009f94 100644 --- a/api/_routers/98-use-rcontext.go +++ b/api/_routers/98-use-rcontext.go @@ -88,20 +88,24 @@ func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { beforeParseDownload: log.Infof("Replying with result: %T %+v", res, res) if downloadRes, isDownload := res.(*_responses.DownloadResponse); isDownload { - ranges, err := http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes) - if errors.Is(err, http_range.ErrInvalid) { - proposedStatusCode = http.StatusRequestedRangeNotSatisfiable - res = _responses.BadRequest("invalid range header") - goto beforeParseDownload // reprocess `res` - } else if errors.Is(err, http_range.ErrNoOverlap) { - proposedStatusCode = http.StatusRequestedRangeNotSatisfiable - res = _responses.BadRequest("out of range") - goto beforeParseDownload // reprocess `res` - } - if len(ranges) > 1 { - proposedStatusCode = http.StatusRequestedRangeNotSatisfiable - res = _responses.BadRequest("only 1 range is supported") - goto beforeParseDownload // reprocess `res` + var ranges []http_range.Range + var err error + if downloadRes.SizeBytes > 0 { + ranges, err = http_range.ParseRange(r.Header.Get("Range"), downloadRes.SizeBytes, rctx.Config.Downloads.DefaultRangeChunkSizeBytes) + if errors.Is(err, http_range.ErrInvalid) { + proposedStatusCode = http.StatusRequestedRangeNotSatisfiable + res = _responses.BadRequest("invalid range header") + goto beforeParseDownload // reprocess `res` + } else if errors.Is(err, http_range.ErrNoOverlap) { + proposedStatusCode = http.StatusRequestedRangeNotSatisfiable + res = _responses.BadRequest("out of range") + goto beforeParseDownload // reprocess `res` + } + if len(ranges) > 1 { + proposedStatusCode = http.StatusRequestedRangeNotSatisfiable + res = _responses.BadRequest("only 1 range is supported") + goto beforeParseDownload // reprocess `res` + } } contentType = "application/octet-stream" diff --git a/api/custom/federation.go b/api/custom/federation.go index 80921ff0..ed324488 100644 --- a/api/custom/federation.go +++ b/api/custom/federation.go @@ -34,6 +34,9 @@ func GetFederationInfo(r *http.Request, rctx rcontext.RequestContext, user _apim versionUrl := url + "/_matrix/federation/v1/version" versionResponse, err := matrix.FederatedGet(versionUrl, hostname, rctx) + if versionResponse != nil { + defer versionResponse.Body.Close() + } if err != nil { rctx.Log.Error(err) sentry.CaptureException(err) diff --git a/api/routes.go b/api/routes.go index 808646dd..49580f18 100644 --- a/api/routes.go +++ b/api/routes.go @@ -18,6 +18,7 @@ import ( const PrefixMedia = "/_matrix/media" const PrefixClient = "/_matrix/client" +const PrefixFederation = "/_matrix/federation" func buildRoutes() http.Handler { counter := &_routers.RequestCounter{} @@ -36,13 +37,29 @@ func buildRoutes() http.Handler { register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId/:filename", mxSpecV3Transition, router, downloadRoute) register([]string{"GET"}, PrefixMedia, "download/:server/:mediaId", mxSpecV3Transition, router, downloadRoute) register([]string{"GET"}, PrefixMedia, "thumbnail/:server/:mediaId", mxSpecV3Transition, router, makeRoute(_routers.OptionalAccessToken(r0.ThumbnailMedia), "thumbnail", counter)) - register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter)) + previewUrlRoute := makeRoute(_routers.RequireAccessToken(r0.PreviewUrl), "url_preview", counter) + register([]string{"GET"}, PrefixMedia, "preview_url", mxSpecV3TransitionCS, router, previewUrlRoute) register([]string{"GET"}, PrefixMedia, "identicon/*seed", mxR0, router, makeRoute(_routers.OptionalAccessToken(r0.Identicon), "identicon", counter)) - register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter)) + configRoute := makeRoute(_routers.RequireAccessToken(r0.PublicConfig), "config", counter) + register([]string{"GET"}, PrefixMedia, "config", mxSpecV3TransitionCS, router, configRoute) register([]string{"POST"}, PrefixClient, "logout", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.Logout), "logout", counter)) register([]string{"POST"}, PrefixClient, "logout/all", mxSpecV3TransitionCS, router, makeRoute(_routers.RequireAccessToken(r0.LogoutAll), "logout_all", counter)) register([]string{"POST"}, PrefixMedia, "create", mxV1, router, makeRoute(_routers.RequireAccessToken(v1.CreateMedia), "create", counter)) + // MSC3916 - Authentication & endpoint API separation + register([]string{"GET"}, PrefixClient, "media/preview_url", msc3916, router, previewUrlRoute) + register([]string{"GET"}, PrefixClient, "media/config", msc3916, router, configRoute) + authedDownloadRoute := makeRoute(_routers.RequireAccessToken(unstable.ClientDownloadMedia), "download", counter) + register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId/:filename", msc3916, router, authedDownloadRoute) + register([]string{"GET"}, PrefixClient, "media/download/:server/:mediaId", msc3916, router, authedDownloadRoute) + register([]string{"GET"}, PrefixClient, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireAccessToken(r0.ThumbnailMedia), "thumbnail", counter)) + register([]string{"GET"}, PrefixFederation, "media/download/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationDownloadMedia), "download", counter)) + register([]string{"GET"}, PrefixFederation, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationThumbnailMedia), "thumbnail", counter)) + + // MSC3911 - Linking media to events + register([]string{"POST"}, PrefixClient, "media/upload", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientUploadMediaSync), "upload", counter)) + register([]string{"POST"}, PrefixClient, "media/create", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientCreateMedia), "create", counter)) + // Custom features register([]string{"GET"}, PrefixMedia, "local_copy/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.LocalCopy), "local_copy", counter)) register([]string{"GET"}, PrefixMedia, "info/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.MediaInfo), "info", counter)) @@ -129,6 +146,8 @@ var ( //mxAllSpec matrixVersions = []string{"r0", "v1", "v3", "unstable", "unstable/io.t2bot.media" /* and MSC routes */} mxUnstable matrixVersions = []string{"unstable", "unstable/io.t2bot.media"} msc4034 matrixVersions = []string{"unstable/org.matrix.msc4034"} + msc3916 matrixVersions = []string{"unstable/org.matrix.msc3916"} + msc3911 matrixVersions = []string{"unstable/org.matrix.msc3911"} mxSpecV3Transition matrixVersions = []string{"r0", "v1", "v3"} mxSpecV3TransitionCS matrixVersions = []string{"r0", "v3"} mxR0 matrixVersions = []string{"r0"} diff --git a/api/unstable/msc3911_create.go b/api/unstable/msc3911_create.go new file mode 100644 index 00000000..bea65622 --- /dev/null +++ b/api/unstable/msc3911_create.go @@ -0,0 +1,50 @@ +package unstable + +import ( + "net/http" + + "github.com/getsentry/sentry-go" + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + v1 "github.com/turt2live/matrix-media-repo/api/v1" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/pipelines/pipeline_create" + "github.com/turt2live/matrix-media-repo/util" +) + +func ClientCreateMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId) + if err != nil { + rctx.Log.Error("Unexpected error creating media ID:", err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") + } + + return &v1.MediaCreatedResponse{ + ContentUri: util.MxcUri(id.Origin, id.MediaId), + ExpiresTs: id.ExpiresTs, + } +} + +func restrictAsyncMediaId(ctx rcontext.RequestContext, host string, userId string) (*database.DbExpiringMedia, error) { + id, err := pipeline_create.Execute(ctx, host, userId, pipeline_create.DefaultExpirationTime) + if err != nil { + return nil, err + } + + db := database.GetInstance().RestrictedMedia.Prepare(ctx) + err = db.Insert(id.Origin, id.MediaId, database.RestrictedToUser, id.UserId) + if err != nil { + // Try to clean up the expiring record, but don't fail if it fails + err2 := database.GetInstance().ExpiringMedia.Prepare(ctx).SetExpiry(id.Origin, id.MediaId, util.NowMillis()) + if err2 != nil { + ctx.Log.Warn("Non-fatal error when trying to clean up interstitial expiring media: ", err2) + sentry.CaptureException(err2) + } + + return nil, err + } + + return id, nil +} diff --git a/api/unstable/msc3911_upload_sync.go b/api/unstable/msc3911_upload_sync.go new file mode 100644 index 00000000..b03a43db --- /dev/null +++ b/api/unstable/msc3911_upload_sync.go @@ -0,0 +1,36 @@ +package unstable + +import ( + "net/http" + + "github.com/getsentry/sentry-go" + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + "github.com/turt2live/matrix-media-repo/api/_routers" + "github.com/turt2live/matrix-media-repo/api/r0" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/util" +) + +func ClientUploadMediaSync(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + // We're a bit fancy here. Instead of mirroring the "upload sync" endpoint to include restricted media, we + // internally create an async media ID then claim it immediately. + + id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId) + if err != nil { + rctx.Log.Error("Unexpected error creating media ID:", err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") + } + + r = _routers.ForceSetParam("server", id.Origin, r) + r = _routers.ForceSetParam("mediaId", id.MediaId, r) + + resp := r0.UploadMediaAsync(r, rctx, user) + if _, ok := resp.(*r0.MediaUploadedResponse); ok { + return &r0.MediaUploadedResponse{ + ContentUri: util.MxcUri(id.Origin, id.MediaId), + } + } + return resp +} diff --git a/api/unstable/msc3916_download.go b/api/unstable/msc3916_download.go new file mode 100644 index 00000000..64d9d5b1 --- /dev/null +++ b/api/unstable/msc3916_download.go @@ -0,0 +1,37 @@ +package unstable + +import ( + "bytes" + "net/http" + + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + "github.com/turt2live/matrix-media-repo/api/r0" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/util/readers" +) + +func ClientDownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + r.URL.Query().Set("allow_remote", "true") + return r0.DownloadMedia(r, rctx, user) +} + +func FederationDownloadMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} { + r.URL.Query().Set("allow_remote", "false") + + res := r0.DownloadMedia(r, rctx, _apimeta.UserInfo{}) + if dl, ok := res.(*_responses.DownloadResponse); ok { + return &_responses.DownloadResponse{ + ContentType: "multipart/mixed", + Filename: "", + SizeBytes: 0, + Data: readers.NewMultipartReader( + &readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))}, + &readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data}, + ), + TargetDisposition: "attachment", + } + } else { + return res + } +} diff --git a/api/unstable/msc3916_thumbnail.go b/api/unstable/msc3916_thumbnail.go new file mode 100644 index 00000000..81b77d20 --- /dev/null +++ b/api/unstable/msc3916_thumbnail.go @@ -0,0 +1,32 @@ +package unstable + +import ( + "bytes" + "net/http" + + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + "github.com/turt2live/matrix-media-repo/api/r0" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/util/readers" +) + +func FederationThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, server _apimeta.ServerInfo) interface{} { + r.URL.Query().Set("allow_remote", "false") + + res := r0.ThumbnailMedia(r, rctx, _apimeta.UserInfo{}) + if dl, ok := res.(*_responses.DownloadResponse); ok { + return &_responses.DownloadResponse{ + ContentType: "multipart/mixed", + Filename: "", + SizeBytes: 0, + Data: readers.NewMultipartReader( + &readers.MultipartPart{ContentType: "application/json", Reader: readers.MakeCloser(bytes.NewReader([]byte("{}")))}, + &readers.MultipartPart{ContentType: dl.ContentType, FileName: dl.Filename, Reader: dl.Data}, + ), + TargetDisposition: "attachment", + } + } else { + return res + } +} diff --git a/config.sample.yaml b/config.sample.yaml index 98d278a5..a4a50412 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -556,7 +556,18 @@ plugins: # Sections of this config might disappear or be added over time. By default all # features are disabled in here and must be explicitly enabled to be used. featureSupport: - # No unstable features are currently supported. + # MSC3911 enables linking media to events, allowing the associated media to be + # deleted when the event is (fully) deleted. MSC3911 support is always enabled + # and requires changes to either the homeserver or how the homeserver is deployed + # to work. + MSC3911: + # How long a "restricted" item of media can exist before it is automatically + # purged from the server. Defaults to 10 minutes. + maxRestrictedAgeMinutes: 10 + + # The maximum number of media items that can be attached to a single event. + # Defaults to 20. + maxAttachEvent: 20 # Support for redis as a cache mechanism # diff --git a/database/db.go b/database/db.go index 9b027ce7..dabb4941 100644 --- a/database/db.go +++ b/database/db.go @@ -28,6 +28,7 @@ type Database struct { Tasks *tasksTableStatements Exports *exportsTableStatements ExportParts *exportPartsTableStatements + RestrictedMedia *restrictedMediaTableStatements } var instance *Database @@ -124,6 +125,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error if d.ExportParts, err = prepareExportPartsTables(d.conn); err != nil { return errors.New("failed to create export parts table accessor: " + err.Error()) } + if d.RestrictedMedia, err = prepareRestrictedMediaTables(d.conn); err != nil { + return errors.New("failed to create restricted media table accessor: " + err.Error()) + } instance = d return nil diff --git a/database/table_expiring_media.go b/database/table_expiring_media.go index 9dfa0854..2bde322b 100644 --- a/database/table_expiring_media.go +++ b/database/table_expiring_media.go @@ -23,6 +23,7 @@ const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_ const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;" const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;" const deleteExpiringMediaById = "DELETE FROM expiring_media WHERE origin = $1 AND media_id = $2;" +const updateExpiringMediaExpiration = "UPDATE expiring_media SET expires_ts = $3 WHERE origin = $1 AND media_id = $2;" // Dev note: there is an UPDATE query in the Upload test suite. @@ -31,6 +32,7 @@ type expiringMediaTableStatements struct { selectExpiringMediaByUserCount *sql.Stmt selectExpiringMediaById *sql.Stmt deleteExpiringMediaById *sql.Stmt + updateExpiringMediaExpiration *sql.Stmt } type expiringMediaTableWithContext struct { @@ -54,6 +56,9 @@ func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, erro if stmts.deleteExpiringMediaById, err = db.Prepare(deleteExpiringMediaById); err != nil { return nil, errors.New("error preparing deleteExpiringMediaById: " + err.Error()) } + if stmts.updateExpiringMediaExpiration, err = db.Prepare(updateExpiringMediaExpiration); err != nil { + return nil, errors.New("error preparing updateExpiringMediaExpiration: " + err.Error()) + } return stmts, nil } @@ -96,3 +101,8 @@ func (s *expiringMediaTableWithContext) Delete(origin string, mediaId string) er _, err := s.statements.deleteExpiringMediaById.ExecContext(s.ctx, origin, mediaId) return err } + +func (s *expiringMediaTableWithContext) SetExpiry(origin string, mediaId string, expiresTs int64) error { + _, err := s.statements.updateExpiringMediaExpiration.ExecContext(s.ctx, origin, mediaId, expiresTs) + return err +} diff --git a/database/table_restricted_media.go b/database/table_restricted_media.go new file mode 100644 index 00000000..5f1fff65 --- /dev/null +++ b/database/table_restricted_media.go @@ -0,0 +1,81 @@ +package database + +import ( + "database/sql" + "errors" + + "github.com/turt2live/matrix-media-repo/common/rcontext" +) + +type RestrictedCondition string + +const RestrictedToEvent RestrictedCondition = "event_id" // MSC3911 +const RestrictedToProfile RestrictedCondition = "profile_user_id" // MSC3911 +const RestrictedToUser RestrictedCondition = "io.t2bot.user_id" // Internal extension + +type DbRestrictedMedia struct { + Origin string + MediaId string + Condition RestrictedCondition + ConditionValue string +} + +const insertRestrictedMedia = "INSERT INTO restricted_media (origin, media_id, condition_type, condition_value) VALUES ($1, $2, $3, $4);" +const updateRestrictedMedia = "UPDATE restricted_media SET condition_type = $3, condition_value = $4 WHERE origin = $1 AND media_id = $2;" +const selectRestrictedMedia = "SELECT origin, media_id, condition_type, condition_value FROM restricted_media WHERE origin = $1 AND media_id = $2 LIMIT 1;" + +type restrictedMediaTableStatements struct { + insertRestrictedMedia *sql.Stmt + updateRestrictedMedia *sql.Stmt + selectRestrictedMedia *sql.Stmt +} + +type restrictedMediaTableWithContext struct { + statements *restrictedMediaTableStatements + ctx rcontext.RequestContext +} + +func prepareRestrictedMediaTables(db *sql.DB) (*restrictedMediaTableStatements, error) { + var err error + var stmts = &restrictedMediaTableStatements{} + + if stmts.insertRestrictedMedia, err = db.Prepare(insertRestrictedMedia); err != nil { + return nil, errors.New("error preparing insertRestrictedMedia: " + err.Error()) + } + if stmts.updateRestrictedMedia, err = db.Prepare(updateRestrictedMedia); err != nil { + return nil, errors.New("error preparing updateRestrictedMedia: " + err.Error()) + } + if stmts.selectRestrictedMedia, err = db.Prepare(selectRestrictedMedia); err != nil { + return nil, errors.New("error preparing selectRestrictedMedia: " + err.Error()) + } + + return stmts, nil +} + +func (s *restrictedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *restrictedMediaTableWithContext { + return &restrictedMediaTableWithContext{ + statements: s, + ctx: ctx, + } +} + +func (s *restrictedMediaTableWithContext) Insert(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error { + _, err := s.statements.insertRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue) + return err +} + +func (s *restrictedMediaTableWithContext) Update(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error { + _, err := s.statements.updateRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue) + return err +} + +func (s *restrictedMediaTableWithContext) GetById(origin string, mediaId string) (*DbRestrictedMedia, error) { + row := s.statements.selectRestrictedMedia.QueryRowContext(s.ctx, origin, mediaId) + val := &DbRestrictedMedia{} + err := row.Scan(&val.Origin, &val.MediaId, &val.Condition, &val.ConditionValue) + if errors.Is(err, sql.ErrNoRows) { + err = nil + val = nil + } + return val, err +} diff --git a/docs/msc3911.md b/docs/msc3911.md new file mode 100644 index 00000000..353208bd --- /dev/null +++ b/docs/msc3911.md @@ -0,0 +1,65 @@ +# [MSC3911](https://github.com/matrix-org/matrix-spec-proposals/pull/3911) API + +*Note*: These docs may change as the MSC progresses towards stability. + +This document describes MMR's API requirements for supporting MSC3911. + +## Authentication + +Servers making requests to MMR's MSC3911 APIs *MUST* include an `Authorization` header containing a `Bearer` token supplied +in the MMR config. This is a per-server token used to identify the sending server. + +Requests lacking this authentication will receive a standard Matrix error response of `{"errcode":"M_UNAUTHORIZED"}` and +status code 401. If the request had authentication, but the server making the request does not have appropriate ability +to affect the given resource, `{"errcode":"M_FORBIDDEN"}` and status code 403 is returned. For example, a server trying +to clean up a user belonging to another server. + +Additionally, the `Host` header received by MMR *MUST* match the requesting server's name, like with all other MMR endpoints. +The server name is the name used in user IDs. + +## Event sending + +The [`PUT /_matrix/client/v3/rooms/{roomId}/state/{eventType}/{stateKey}`](https://spec.matrix.org/v1.8/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey) +and [`PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`](https://spec.matrix.org/v1.8/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid) +endpoints support a new `attach_media` repeated parameter under MSC3911. While unstable, this new parameter is `org.matrix.msc3911.attach_media`. + +The MMR API to associate an event ID with the user-specified media is `POST /_mmr/v1/attach/event/{eventId}` with JSON body +being the list of associated MXC URIs. If unsuccessful, MMR will return a standard Matrix error response consistent with +MSC3911: `{"errcode":"M_INVALID_PARAM"}` and 400 status code. + +## User profiles + +Similar to event sending, [`PUT /_matrix/client/v3/profile/{userId}/avatar_url`](https://spec.matrix.org/v1.8/client-server-api/#put_matrixclientv3profileuseridavatar_url) +automatically supports attaching to the user's profile. + +The MMR API to associate that media with the user's profile is `POST /_mmr/v1/attach/user/{userId}` with JSON body being +a list of associated MXC URIs (should normally be a single entry). If unsuccessful, MMR will return a standard Matrix +error response consistent with MSC3911: `{"errcode":"M_INVALID_PARAM"}` and 400 status code. + +## Redactions + +When a redaction occurs, the server should hit `POST /_mmr/v1/cleanup/event/{eventId}` with an empty JSON object for a body. + +MMR will return `{"deleted":2}` to denote how many pieces of media were removed. If there were no linked media items, +`deleted: 0` will be returned. + +## User deactivations + +When a user is deactivated, the server should hit `POST /_mmr/v1/cleanup/user/{userId}` with an empty JSON object for a +body. + +MMR will return `{"deleted":2}` to denote how many pieces of media were removed. If there were no linked media items, +`deleted: 0` will be returned. + +## Visibility check + +When MMR needs to ensure a given user or server can see a particular event or profile, it will use the following endpoints: + +* `GET /_mmr-upstream/v1/visibility/event/{targetEventId}/user/{viewingUserId}` with no request body. +* `GET /_mmr-upstream/v1/visibility/event/{targetEventId}/server/{viewingServerName}` with no request body. +* `GET /_mmr-upstream/v1/visibility/profile/{targetUserId}/user/{viewingUserId}` with no request body. +* `GET /_mmr-upstream/v1/visibility/profile/{targetUserId}/server/{viewingServerName}` with no request body. + +In all cases, a `404 M_NOT_FOUND` error should be returned if visibility is not possible or the parameters are malformed. +If the server does not support the endpoints, it *MUST* return `404 M_UNRECOGNIZED` per the Matrix specification on unknown +endpoints. If visibility is permitted, a 200 status code is returned. MMR will discard the response body on 200 OK. diff --git a/matrix/requests_signing.go b/matrix/requests_signing.go new file mode 100644 index 00000000..15e37997 --- /dev/null +++ b/matrix/requests_signing.go @@ -0,0 +1,150 @@ +package matrix + +import ( + "crypto/ed25519" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" + "github.com/t2bot/go-typed-singleflight" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/util" +) + +type signingKey struct { + Key string `json:"key"` +} + +type serverKeyResult struct { + ServerName string `json:"server_name"` + ValidUntilTs int64 `json:"valid_until_ts"` + VerifyKeys map[string]signingKey `json:"verify_keys"` // unpadded base64 + OldVerifyKeys map[string]signingKey `json:"old_verify_keys"` // unpadded base64 + Signatures map[string]map[string]string `json:"signatures"` // unpadded base64; > +} + +type ServerSigningKeys map[string]ed25519.PublicKey + +var signingKeySf = new(typedsf.Group[*ServerSigningKeys]) +var signingKeyCache = cache.New(cache.NoExpiration, 30*time.Second) +var signingKeyRWLock = new(sync.RWMutex) + +func querySigningKeyCache(serverName string) *ServerSigningKeys { + if val, ok := signingKeyCache.Get(serverName); ok { + return val.(*ServerSigningKeys) + } + return nil +} + +func QuerySigningKeys(serverName string) (*ServerSigningKeys, error) { + signingKeyRWLock.RLock() + keys := querySigningKeyCache(serverName) + signingKeyRWLock.RUnlock() + if keys != nil { + return keys, nil + } + + keys, err, _ := signingKeySf.Do(serverName, func() (*ServerSigningKeys, error) { + ctx := rcontext.Initial().LogWithFields(logrus.Fields{ + "keysForServer": serverName, + }) + + signingKeyRWLock.Lock() + defer signingKeyRWLock.Unlock() + + // check cache once more, just in case the locks overlapped + cachedKeys := querySigningKeyCache(serverName) + if keys != nil { + return cachedKeys, nil + } + + // now we can try to get the keys from the source + url, hostname, err := GetServerApiUrl(serverName) + if err != nil { + return nil, err + } + + keysUrl := url + "/_matrix/key/v2/server" + keysResponse, err := FederatedGet(keysUrl, hostname, ctx) + if keysResponse != nil { + defer keysResponse.Body.Close() + } + if err != nil { + return nil, err + } + + decoder := json.NewDecoder(keysResponse.Body) + raw := database.AnonymousJson{} + if err = decoder.Decode(&raw); err != nil { + return nil, err + } + keyInfo := new(serverKeyResult) + if err = raw.ApplyTo(keyInfo); err != nil { + return nil, err + } + + // Check validity before we go much further + if keyInfo.ServerName != serverName { + return nil, fmt.Errorf("got keys for '%s' but expected '%s'", keyInfo.ServerName, serverName) + } + if keyInfo.ValidUntilTs <= util.NowMillis() { + return nil, errors.New("returned server keys are expired") + } + cacheUntil := time.Until(time.UnixMilli(keyInfo.ValidUntilTs)) / 2 + if cacheUntil <= (6 * time.Second) { + return nil, errors.New("returned server keys would expire too quickly") + } + + // Convert to something useful + serverKeys := make(ServerSigningKeys) + for keyId, keyObj := range keyInfo.VerifyKeys { + b, err := util.DecodeUnpaddedBase64String(keyObj.Key) + if err != nil { + return nil, errors.Join(fmt.Errorf("bad base64 for key ID '%s' for '%s'", keyId, serverName), err) + } + + serverKeys[keyId] = b + } + + // Check signatures + if len(keyInfo.Signatures) == 0 || len(keyInfo.Signatures[serverName]) == 0 { + return nil, fmt.Errorf("missing signatures from '%s'", serverName) + } + delete(raw, "signatures") + canonical, err := util.EncodeCanonicalJson(raw) + if err != nil { + return nil, err + } + for domain, sig := range keyInfo.Signatures { + if domain != serverName { + return nil, fmt.Errorf("unexpected signature from '%s' (expected '%s')", domain, serverName) + } + + for keyId, b64 := range sig { + signatureBytes, err := util.DecodeUnpaddedBase64String(b64) + if err != nil { + return nil, errors.Join(fmt.Errorf("bad base64 signature for key ID '%s' for '%s'", keyId, serverName), err) + } + + key, ok := serverKeys[keyId] + if !ok { + return nil, fmt.Errorf("unknown key ID '%s' for signature from '%s'", keyId, serverName) + } + + if !ed25519.Verify(key, canonical, signatureBytes) { + return nil, fmt.Errorf("invalid signature '%s' from key ID '%s' for '%s'", b64, keyId, serverName) + } + } + } + + // Cache & return (unlock was deferred) + signingKeyCache.Set(serverName, &serverKeys, cacheUntil) + return &serverKeys, nil + }) + return keys, err +} diff --git a/matrix/xmatrix.go b/matrix/xmatrix.go new file mode 100644 index 00000000..72482f13 --- /dev/null +++ b/matrix/xmatrix.go @@ -0,0 +1,65 @@ +package matrix + +import ( + "crypto/ed25519" + "errors" + "fmt" + "github.com/turt2live/matrix-media-repo/util" + "net/http" +) + +var ErrNoXMatrixAuth = errors.New("no X-Matrix auth headers") + +func ValidateXMatrixAuth(request *http.Request, expectNoContent bool) (string, error) { + if !expectNoContent { + panic("development error: X-Matrix auth validation can only be done with an empty body for now") + } + + auths, err := util.GetXMatrixAuth(request) + if err != nil { + return "", err + } + + if len(auths) == 0 { + return "", ErrNoXMatrixAuth + } + + obj := map[string]interface{}{ + "method": request.Method, + "uri": request.RequestURI, + "origin": auths[0].Origin, + "destination": auths[0].Destination, + "content": "{}", + } + canonical, err := util.EncodeCanonicalJson(obj) + if err != nil { + return "", err + } + + keys, err := QuerySigningKeys(auths[0].Origin) + if err != nil { + return "", err + } + + for _, h := range auths { + if h.Origin != obj["origin"] { + return "", errors.New("auth is from multiple servers") + } + if h.Destination != obj["destination"] { + return "", errors.New("auth is for multiple servers") + } + if h.Destination != "" && !util.IsServerOurs(h.Destination) { + return "", errors.New("unknown destination") + } + + if key, ok := (*keys)[h.KeyId]; ok { + if !ed25519.Verify(key, canonical, h.Signature) { + return "", fmt.Errorf("failed signatures on '%s'", h.KeyId) + } + } else { + return "", fmt.Errorf("unknown key '%s'", h.KeyId) + } + } + + return auths[0].Origin, nil +} diff --git a/migrations/27_create_restricted_media_table_down.sql b/migrations/27_create_restricted_media_table_down.sql new file mode 100644 index 00000000..62ae20d6 --- /dev/null +++ b/migrations/27_create_restricted_media_table_down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_restricted_media; +DROP TABLE IF EXISTS restricted_media; diff --git a/migrations/27_create_restricted_media_table_up.sql b/migrations/27_create_restricted_media_table_up.sql new file mode 100644 index 00000000..13b0fa6c --- /dev/null +++ b/migrations/27_create_restricted_media_table_up.sql @@ -0,0 +1,2 @@ +CREATE TABLE IF NOT EXISTS restricted_media (origin TEXT NOT NULL, media_id TEXT NOT NULL, condition_type TEXT NOT NULL, condition_value TEXT NOT NULL); +CREATE UNIQUE INDEX IF NOT EXISTS idx_restricted_media ON restricted_media (origin, media_id); diff --git a/test/canonical_json_test.go b/test/canonical_json_test.go new file mode 100644 index 00000000..c0756ddc --- /dev/null +++ b/test/canonical_json_test.go @@ -0,0 +1,130 @@ +/* + * Copyright 2019 Travis Ralston + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test + +import ( + "testing" + + "github.com/turt2live/matrix-media-repo/util" +) + +func TestEncodeCanonicalJson_CaseA(t *testing.T) { + input := map[string]interface{}{} + expectedOutput := []byte("{}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func TestEncodeCanonicalJson_CaseB(t *testing.T) { + input := map[string]interface{}{ + "one": 1, + "two": "Two", + } + expectedOutput := []byte("{\"one\":1,\"two\":\"Two\"}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func TestEncodeCanonicalJson_CaseC(t *testing.T) { + input := map[string]interface{}{ + "b": "2", + "a": "1", + } + expectedOutput := []byte("{\"a\":\"1\",\"b\":\"2\"}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func TestEncodeCanonicalJson_CaseD(t *testing.T) { + input := map[string]interface{}{ + "auth": map[string]interface{}{ + "success": true, + "mxid": "@john.doe:example.com", + "profile": map[string]interface{}{ + "display_name": "John Doe", + "three_pids": []map[string]interface{}{ + { + "medium": "email", + "address": "john.doe@example.org", + }, + { + "medium": "msisdn", + "address": "123456789", + }, + }, + }, + }, + } + expectedOutput := []byte("{\"auth\":{\"mxid\":\"@john.doe:example.com\",\"profile\":{\"display_name\":\"John Doe\",\"three_pids\":[{\"address\":\"john.doe@example.org\",\"medium\":\"email\"},{\"address\":\"123456789\",\"medium\":\"msisdn\"}]},\"success\":true}}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func TestEncodeCanonicalJson_CaseE(t *testing.T) { + input := map[string]interface{}{ + "a": "日本語", + } + expectedOutput := []byte("{\"a\":\"日本語\"}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func TestEncodeCanonicalJson_CaseF(t *testing.T) { + input := map[string]interface{}{ + "本": 2, + "日": 1, + } + expectedOutput := []byte("{\"日\":1,\"本\":2}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func TestEncodeCanonicalJson_CaseG(t *testing.T) { + input := map[string]interface{}{ + "a": "\u65E5", + } + expectedOutput := []byte("{\"a\":\"日\"}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func TestEncodeCanonicalJson_CaseH(t *testing.T) { + input := map[string]interface{}{ + "a": nil, + } + expectedOutput := []byte("{\"a\":null}") + actualOutput, _ := util.EncodeCanonicalJson(input) + compareBytes(expectedOutput, actualOutput, t) +} + +func compareBytes(expected []byte, actual []byte, t *testing.T) { + if len(expected) != len(actual) { + t.Errorf("Mismatched length: %d != %d", len(actual), len(expected)) + t.Fail() + return + } + + for i := range expected { + e := expected[i] + a := actual[i] + if e != a { + t.Errorf("Expected %b but got %b at index %d", e, a, i) + t.Fail() + return + } + } +} diff --git a/util/canonical_json.go b/util/canonical_json.go new file mode 100644 index 00000000..d680d51b --- /dev/null +++ b/util/canonical_json.go @@ -0,0 +1,36 @@ +/* + * Copyright 2019 Travis Ralston + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "bytes" + "encoding/json" +) + +func EncodeCanonicalJson(obj map[string]interface{}) ([]byte, error) { + b, err := json.Marshal(obj) + if err != nil { + return nil, err + } + + // De-encode values + b = bytes.Replace(b, []byte("\\u003c"), []byte("<"), -1) + b = bytes.Replace(b, []byte("\\u003e"), []byte(">"), -1) + b = bytes.Replace(b, []byte("\\u0026"), []byte("&"), -1) + + return b, nil +} diff --git a/util/http.go b/util/http.go index 2188feb3..ac5144a4 100644 --- a/util/http.go +++ b/util/http.go @@ -1,11 +1,19 @@ package util import ( + "fmt" "net/http" "net/url" "strings" ) +type XMatrixAuth struct { + Origin string + Destination string + KeyId string + Signature []byte +} + func GetAccessTokenFromRequest(request *http.Request) string { token := request.Header.Get("Authorization") @@ -40,3 +48,72 @@ func GetLogSafeUrl(r *http.Request) string { copyUrl.RawQuery = GetLogSafeQueryString(r) return copyUrl.String() } + +func GetXMatrixAuth(request *http.Request) ([]XMatrixAuth, error) { + headers := request.Header.Values("Authorization") + auths := make([]XMatrixAuth, 0) + for _, h := range headers { + if !strings.HasPrefix(h, "X-Matrix ") { + continue + } + + paramCsv := h[len("X-Matrix "):] + params := make(map[string]string) + isKey := true + keyName := "" + keyValue := "" + escape := false + for _, c := range paramCsv { + if c == ',' && isKey { + params[strings.TrimSpace(strings.ToLower(keyName))] = keyValue + keyName = "" + keyValue = "" + continue + } + if c == '=' { + isKey = false + continue + } + + if isKey { + keyName = fmt.Sprintf("%s%s", keyName, string(c)) + } else { + if c == '\\' && !escape { + escape = true + continue + } + if c == '"' && !escape { + escape = false + if len(keyValue) > 0 { + isKey = true + } + continue + } + if escape { + escape = false + } + keyValue = fmt.Sprintf("%s%s", keyValue, string(c)) + } + } + if len(keyName) > 0 && isKey { + params[strings.TrimSpace(strings.ToLower(keyName))] = keyValue + } + + sig, err := DecodeUnpaddedBase64String(params["sig"]) + if err != nil { + return nil, err + } + auth := XMatrixAuth{ + Origin: params["origin"], + Destination: params["destination"], + KeyId: params["key"], + Signature: sig, + } + if auth.Origin == "" || auth.KeyId == "" || len(auth.Signature) == 0 { + continue + } + auths = append(auths, auth) + } + + return auths, nil +} diff --git a/util/readers/multipart_reader.go b/util/readers/multipart_reader.go new file mode 100644 index 00000000..ea3c901d --- /dev/null +++ b/util/readers/multipart_reader.go @@ -0,0 +1,57 @@ +package readers + +import ( + "io" + "mime/multipart" + "net/textproto" + "net/url" + + "github.com/alioygur/is" +) + +type MultipartPart struct { + ContentType string + FileName string + Reader io.ReadCloser +} + +func NewMultipartReader(parts ...*MultipartPart) io.ReadCloser { + r, w := io.Pipe() + go func() { + mpw := multipart.NewWriter(w) + + for _, part := range parts { + headers := textproto.MIMEHeader{} + if part.ContentType != "" { + headers.Set("Content-Type", part.ContentType) + } + if part.FileName != "" { + if is.ASCII(part.FileName) { + headers.Set("Content-Disposition", "attachment; filename="+url.QueryEscape(part.FileName)) + } else { + headers.Set("Content-Disposition", "attachment; filename*=utf-8''"+url.QueryEscape(part.FileName)) + } + } + + partW, err := mpw.CreatePart(headers) + if err != nil { + _ = w.CloseWithError(err) + return + } + if _, err = io.Copy(partW, part.Reader); err != nil { + _ = w.CloseWithError(err) + return + } + if err = part.Reader.Close(); err != nil { + _ = w.CloseWithError(err) + return + } + } + + if err := mpw.Close(); err != nil { + _ = w.CloseWithError(err) + } + _ = w.Close() + }() + return MakeCloser(r) +} diff --git a/util/unpadded_base64.go b/util/unpadded_base64.go new file mode 100644 index 00000000..aeb1d382 --- /dev/null +++ b/util/unpadded_base64.go @@ -0,0 +1,29 @@ +/* + * Copyright 2019 Travis Ralston + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package util + +import ( + "encoding/base64" +) + +func DecodeUnpaddedBase64String(val string) ([]byte, error) { + return base64.RawStdEncoding.DecodeString(val) +} + +func EncodeUnpaddedBase64ToString(val []byte) string { + return base64.RawStdEncoding.EncodeToString(val) +}