[header-override-650] implement header overriding in GetObject (fix #650)

This commit is contained in:
Alex Auvolat 2024-02-09 15:34:42 +01:00
parent 3865080c35
commit 02e98e2d10
No known key found for this signature in database
GPG key ID: 0E496D15096376BE
5 changed files with 131 additions and 8 deletions

View file

@ -178,8 +178,26 @@ impl ApiHandler for S3ApiServer {
key, part_number, .. key, part_number, ..
} => handle_head(garage, &req, bucket_id, &key, part_number).await, } => handle_head(garage, &req, bucket_id, &key, part_number).await,
Endpoint::GetObject { Endpoint::GetObject {
key, part_number, .. key,
} => handle_get(garage, &req, bucket_id, &key, part_number).await, part_number,
response_cache_control,
response_content_disposition,
response_content_encoding,
response_content_language,
response_content_type,
response_expires,
..
} => {
let overrides = GetObjectOverrides {
response_cache_control,
response_content_disposition,
response_content_encoding,
response_content_language,
response_content_type,
response_expires,
};
handle_get(garage, &req, bucket_id, &key, part_number, overrides).await
}
Endpoint::UploadPart { Endpoint::UploadPart {
key, key,
part_number, part_number,

View file

@ -1,12 +1,14 @@
//! Function related to GET and HEAD requests //! Function related to GET and HEAD requests
use std::convert::TryInto;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, UNIX_EPOCH}; use std::time::{Duration, UNIX_EPOCH};
use futures::future; use futures::future;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use http::header::{ use http::header::{
ACCEPT_RANGES, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, ETAG, IF_MODIFIED_SINCE, ACCEPT_RANGES, CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LANGUAGE,
IF_NONE_MATCH, LAST_MODIFIED, RANGE, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, ETAG, EXPIRES, IF_MODIFIED_SINCE, IF_NONE_MATCH,
LAST_MODIFIED, RANGE,
}; };
use hyper::{body::Body, Request, Response, StatusCode}; use hyper::{body::Body, Request, Response, StatusCode};
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -27,6 +29,16 @@ use crate::s3::error::*;
const X_AMZ_MP_PARTS_COUNT: &str = "x-amz-mp-parts-count"; const X_AMZ_MP_PARTS_COUNT: &str = "x-amz-mp-parts-count";
#[derive(Default)]
pub struct GetObjectOverrides {
pub(crate) response_cache_control: Option<String>,
pub(crate) response_content_disposition: Option<String>,
pub(crate) response_content_encoding: Option<String>,
pub(crate) response_content_language: Option<String>,
pub(crate) response_content_type: Option<String>,
pub(crate) response_expires: Option<String>,
}
fn object_headers( fn object_headers(
version: &ObjectVersion, version: &ObjectVersion,
version_meta: &ObjectVersionMeta, version_meta: &ObjectVersionMeta,
@ -52,6 +64,32 @@ fn object_headers(
resp resp
} }
/// Override headers according to specific query parameters, see
/// section "Overriding response header values through the request" in
/// https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html
fn getobject_override_headers(
overrides: GetObjectOverrides,
resp: &mut http::response::Builder,
) -> Result<(), Error> {
// TODO: this only applies for signed requests, so when we support
// anonymous access in the future we will have to do a permission check here
let overrides = [
(CACHE_CONTROL, overrides.response_cache_control),
(CONTENT_DISPOSITION, overrides.response_content_disposition),
(CONTENT_ENCODING, overrides.response_content_encoding),
(CONTENT_LANGUAGE, overrides.response_content_language),
(CONTENT_TYPE, overrides.response_content_type),
(EXPIRES, overrides.response_expires),
];
for (hdr, val_opt) in overrides {
if let Some(val) = val_opt {
let val = val.try_into().ok_or_bad_request("invalid header value")?;
resp.headers_mut().unwrap().insert(hdr, val);
}
}
Ok(())
}
fn try_answer_cached( fn try_answer_cached(
version: &ObjectVersion, version: &ObjectVersion,
version_meta: &ObjectVersionMeta, version_meta: &ObjectVersionMeta,
@ -185,6 +223,7 @@ pub async fn handle_get(
bucket_id: Uuid, bucket_id: Uuid,
key: &str, key: &str,
part_number: Option<u64>, part_number: Option<u64>,
overrides: GetObjectOverrides,
) -> Result<Response<ResBody>, Error> { ) -> Result<Response<ResBody>, Error> {
let object = garage let object = garage
.object_table .object_table
@ -236,9 +275,10 @@ pub async fn handle_get(
(None, None) => (), (None, None) => (),
} }
let resp_builder = object_headers(last_v, last_v_meta) let mut resp_builder = object_headers(last_v, last_v_meta)
.header(CONTENT_LENGTH, format!("{}", last_v_meta.size)) .header(CONTENT_LENGTH, format!("{}", last_v_meta.size))
.status(StatusCode::OK); .status(StatusCode::OK);
getobject_override_headers(overrides, &mut resp_builder)?;
match &last_v_data { match &last_v_data {
ObjectVersionData::DeleteMarker => unreachable!(), ObjectVersionData::DeleteMarker => unreachable!(),
@ -303,6 +343,9 @@ async fn handle_get_range(
begin: u64, begin: u64,
end: u64, end: u64,
) -> Result<Response<ResBody>, Error> { ) -> Result<Response<ResBody>, Error> {
// Here we do not use getobject_override_headers because we don't
// want to add any overridden headers (those should not be added
// when returning PARTIAL_CONTENT)
let resp_builder = object_headers(version, version_meta) let resp_builder = object_headers(version, version_meta)
.header(CONTENT_LENGTH, format!("{}", end - begin)) .header(CONTENT_LENGTH, format!("{}", end - begin))
.header( .header(
@ -343,6 +386,7 @@ async fn handle_get_part(
version_meta: &ObjectVersionMeta, version_meta: &ObjectVersionMeta,
part_number: u64, part_number: u64,
) -> Result<Response<ResBody>, Error> { ) -> Result<Response<ResBody>, Error> {
// Same as for get_range, no getobject_override_headers
let resp_builder = let resp_builder =
object_headers(object_version, version_meta).status(StatusCode::PARTIAL_CONTENT); object_headers(object_version, version_meta).status(StatusCode::PARTIAL_CONTENT);

View file

@ -125,6 +125,12 @@ pub enum Endpoint {
key: String, key: String,
part_number: Option<u64>, part_number: Option<u64>,
version_id: Option<String>, version_id: Option<String>,
response_cache_control: Option<String>,
response_content_disposition: Option<String>,
response_content_encoding: Option<String>,
response_content_language: Option<String>,
response_content_type: Option<String>,
response_expires: Option<String>,
}, },
GetObjectAcl { GetObjectAcl {
key: String, key: String,
@ -358,7 +364,14 @@ impl Endpoint {
(query.keyword.take().unwrap_or_default(), key, query, None), (query.keyword.take().unwrap_or_default(), key, query, None),
key: [ key: [
EMPTY if upload_id => ListParts (query::upload_id, opt_parse::max_parts, opt_parse::part_number_marker), EMPTY if upload_id => ListParts (query::upload_id, opt_parse::max_parts, opt_parse::part_number_marker),
EMPTY => GetObject (query_opt::version_id, opt_parse::part_number), EMPTY => GetObject (query_opt::version_id,
opt_parse::part_number,
query_opt::response_cache_control,
query_opt::response_content_disposition,
query_opt::response_content_encoding,
query_opt::response_content_language,
query_opt::response_content_type,
query_opt::response_expires),
ACL => GetObjectAcl (query_opt::version_id), ACL => GetObjectAcl (query_opt::version_id),
LEGAL_HOLD => GetObjectLegalHold (query_opt::version_id), LEGAL_HOLD => GetObjectLegalHold (query_opt::version_id),
RETENTION => GetObjectRetention (query_opt::version_id), RETENTION => GetObjectRetention (query_opt::version_id),
@ -671,6 +684,12 @@ generateQueryParameters! {
"partNumber" => part_number, "partNumber" => part_number,
"part-number-marker" => part_number_marker, "part-number-marker" => part_number_marker,
"prefix" => prefix, "prefix" => prefix,
"response-cache-control" => response_cache_control,
"response-content-disposition" => response_content_disposition,
"response-content-encoding" => response_content_encoding,
"response-content-language" => response_content_language,
"response-content-type" => response_content_type,
"response-expires" => response_expires,
"select-type" => select_type, "select-type" => select_type,
"start-after" => start_after, "start-after" => start_after,
"uploadId" => upload_id, "uploadId" => upload_id,

View file

@ -185,6 +185,30 @@ async fn test_getobject() {
assert_eq!(o.content_range.unwrap().as_str(), "bytes 57-61/62"); assert_eq!(o.content_range.unwrap().as_str(), "bytes 57-61/62");
assert_bytes_eq!(o.body, &BODY[57..]); assert_bytes_eq!(o.body, &BODY[57..]);
} }
{
let exp = aws_sdk_s3::primitives::DateTime::from_secs(10000000000);
let o = ctx
.client
.get_object()
.bucket(&bucket)
.key(STD_KEY)
.response_content_type("application/x-dummy-test")
.response_cache_control("ccdummy")
.response_content_disposition("cddummy")
.response_content_encoding("cedummy")
.response_content_language("cldummy")
.response_expires(exp)
.send()
.await
.unwrap();
assert_eq!(o.content_type.unwrap().as_str(), "application/x-dummy-test");
assert_eq!(o.cache_control.unwrap().as_str(), "ccdummy");
assert_eq!(o.content_disposition.unwrap().as_str(), "cddummy");
assert_eq!(o.content_encoding.unwrap().as_str(), "cedummy");
assert_eq!(o.content_language.unwrap().as_str(), "cldummy");
assert_eq!(o.expires.unwrap(), exp);
assert_bytes_eq!(o.body, &BODY[..]);
}
} }
#[tokio::test] #[tokio::test]

View file

@ -247,7 +247,17 @@ impl WebServer {
.map_err(ApiError::from) .map_err(ApiError::from)
.map(|res| res.map(|_empty_body: EmptyBody| empty_body())), .map(|res| res.map(|_empty_body: EmptyBody| empty_body())),
Method::HEAD => handle_head(self.garage.clone(), &req, bucket_id, &key, None).await, Method::HEAD => handle_head(self.garage.clone(), &req, bucket_id, &key, None).await,
Method::GET => handle_get(self.garage.clone(), &req, bucket_id, &key, None).await, Method::GET => {
handle_get(
self.garage.clone(),
&req,
bucket_id,
&key,
None,
Default::default(),
)
.await
}
_ => Err(ApiError::bad_request("HTTP method not supported")), _ => Err(ApiError::bad_request("HTTP method not supported")),
}; };
@ -291,7 +301,15 @@ impl WebServer {
.body(empty_body::<Infallible>()) .body(empty_body::<Infallible>())
.unwrap(); .unwrap();
match handle_get(self.garage.clone(), &req2, bucket_id, &error_document, None).await match handle_get(
self.garage.clone(),
&req2,
bucket_id,
&error_document,
None,
Default::default(),
)
.await
{ {
Ok(mut error_doc) => { Ok(mut error_doc) => {
// The error won't be logged back in handle_request, // The error won't be logged back in handle_request,