diff --git a/internal/api/fileserver/servefile.go b/internal/api/fileserver/servefile.go index 5e36d02a..d9fc99b5 100644 --- a/internal/api/fileserver/servefile.go +++ b/internal/api/fileserver/servefile.go @@ -19,17 +19,17 @@ package fileserver import ( - "bytes" "fmt" "io" "net/http" "strconv" + "strings" + "codeberg.org/gruf/go-fastcopy" "github.com/gin-gonic/gin" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/iotools" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/oauth" ) @@ -87,20 +87,19 @@ func (m *Module) ServeFile(c *gin.Context) { return } - defer func() { - // close content when we're done - if content.Content != nil { - if err := content.Content.Close(); err != nil { - log.Errorf("ServeFile: error closing readcloser: %s", err) - } - } - }() - if content.URL != nil { + // This is a non-local S3 file we're proxying to. c.Redirect(http.StatusFound, content.URL.String()) return } + defer func() { + // Close content when we're done, catch errors. + if err := content.Content.Close(); err != nil { + log.Errorf("ServeFile: error closing readcloser: %s", err) + } + }() + // TODO: if the requester only accepts text/html we should try to serve them *something*. // This is mostly needed because when sharing a link to a gts-hosted file on something like mastodon, the masto servers will // attempt to look up the content to provide a preview of the link, and they ask for text/html. @@ -118,45 +117,123 @@ func (m *Module) ServeFile(c *gin.Context) { return } - // create a "slurp" buffer ;) - b := make([]byte, 64) - - // Try read the first 64 bytes into memory, to try return a more useful "not found" error. - if _, err := io.ReadFull(content.Content, b); err != nil && - (err != io.ErrUnexpectedEOF && err != io.EOF) { - err = fmt.Errorf("ServeFile: error reading from content: %w", err) - apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(err, err.Error()), m.processor.InstanceGetV1) + // Look for a provided range header. + rng := c.GetHeader("Range") + if rng == "" { + // This is a simple query for the whole file, so do a read from whole reader. + c.DataFromReader(http.StatusOK, content.ContentLength, format, content.Content, nil) return } - // reconstruct the original content reader - r := io.MultiReader(bytes.NewReader(b), content.Content) - - // Check the Range header: if this is a simple query for the whole file, we can return it now. - if c.GetHeader("Range") == "" && c.GetHeader("If-Range") == "" { - c.DataFromReader(http.StatusOK, content.ContentLength, format, r, nil) - return - } - - // Range is set, so we need a ReadSeeker to pass to the ServeContent function. - tfs, err := iotools.TempFileSeeker(r) - if err != nil { - err = fmt.Errorf("ServeFile: error creating temp file seeker: %w", err) - apiutil.ErrorHandler(c, gtserror.NewErrorInternalError(err), m.processor.InstanceGetV1) - return - } - defer func() { - if err := tfs.Close(); err != nil { - log.Errorf("ServeFile: error closing temp file seeker: %s", err) - } - }() - - // to avoid ServeContent wasting time seeking for the - // mime type, set this header already since we know it + // Set known content-type and serve this file range. c.Header("Content-Type", format) - - // allow ServeContent to handle the rest of the request; - // it will handle Range as appropriate, and write correct - // response headers, http code, etc - http.ServeContent(c.Writer, c.Request, fileName, content.ContentUpdated, tfs) + serveFileRange(c.Writer, content.Content, rng, content.ContentLength) +} + +// serveFileRange serves the range of a file from a given source reader, without the +// need for implementation of io.Seeker. Instead we read the first 'start' many bytes +// into a discard reader. Code is adapted from https://codeberg.org/gruf/simplehttp. +func serveFileRange(rw http.ResponseWriter, src io.Reader, rng string, size int64) { + var i int + + if i = strings.IndexByte(rng, '='); i < 0 { + // Range must include a separating '=' to indicate start + http.Error(rw, "Bad Range Header", http.StatusBadRequest) + return + } + + if rng[:i] != "bytes" { + // We only support byte ranges in our implementation + http.Error(rw, "Unsupported Range Unit", http.StatusBadRequest) + return + } + + // Reslice past '=' + rng = rng[i+1:] + + if i = strings.IndexByte(rng, '-'); i < 0 { + // Range header must contain a beginning and end separated by '-' + http.Error(rw, "Bad Range Header", http.StatusBadRequest) + return + } + + var ( + err error + + // default start + end ranges + start, end = int64(0), size - 1 + + // start + end range strings + startRng, endRng string + ) + + if startRng = rng[:i]; len(startRng) > 0 { + // Parse the start of this byte range + start, err = strconv.ParseInt(startRng, 10, 64) + if err != nil { + http.Error(rw, "Bad Range Header", http.StatusBadRequest) + return + } + + if start < 0 { + // This range starts *before* the file start, why did they send this lol + rw.Header().Set("Content-Range", "bytes *"+strconv.FormatInt(size, 10)) + http.Error(rw, "Unsatisfiable Range", http.StatusRequestedRangeNotSatisfiable) + return + } + } else { + // No start supplied, implying file start + startRng = "0" + } + + if endRng = rng[i+1:]; len(endRng) > 0 { + // Parse the end of this byte range + end, err = strconv.ParseInt(endRng, 10, 64) + if err != nil { + http.Error(rw, "Bad Range Header", http.StatusBadRequest) + return + } + + if end > size { + // This range exceeds length of the file, therefore unsatisfiable + rw.Header().Set("Content-Range", "bytes *"+strconv.FormatInt(size, 10)) + http.Error(rw, "Unsatisfiable Range", http.StatusRequestedRangeNotSatisfiable) + return + } + } else { + // No end supplied, implying file end + endRng = strconv.FormatInt(end, 10) + } + + if start >= end { + // This range starts _after_ their range end, unsatisfiable and nonsense! + rw.Header().Set("Content-Range", "bytes *"+strconv.FormatInt(size, 10)) + http.Error(rw, "Unsatisfiable Range", http.StatusRequestedRangeNotSatisfiable) + return + } + + // Dump the first 'start' many bytes into the void... + if _, err := fastcopy.CopyN(io.Discard, src, start); err != nil { + log.Errorf("error reading from source: %v", err) + return + } + + // Determine content len + length := end - start + + if end < size-1 { + // Range end < file end, limit the reader + src = io.LimitReader(src, length) + } + + // Write the necessary length and range headers + rw.Header().Set("Content-Range", "bytes "+startRng+"-"+endRng+"/"+strconv.FormatInt(size, 10)) + rw.Header().Set("Content-Length", strconv.FormatInt(length, 10)) + rw.WriteHeader(http.StatusPartialContent) + + // Read the "seeked" source reader into destination writer. + if _, err := fastcopy.Copy(rw, src); err != nil { + log.Errorf("error reading from source: %v", err) + return + } }