Add ability to set host header

This commit is contained in:
asonix 2020-09-07 16:42:06 -05:00
parent 7538969a0b
commit 2c031a7a1d
3 changed files with 41 additions and 2 deletions

View file

@ -37,7 +37,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
std::env::set_var("RUST_LOG", "info");
pretty_env_logger::init();
let config = Config::default().require_header("accept");
let config = Config::default().require_header("accept").set_host_header();
request(config.clone()).await?;
request(config.dont_use_created_field()).await?;

View file

@ -279,6 +279,9 @@ pub trait Sign {
pub struct Config {
/// The inner config type
config: http_signature_normalization::Config,
/// Whether to set the Host header
set_host: bool,
}
#[derive(Debug, thiserror::Error)]
@ -315,6 +318,10 @@ pub enum PrepareSignError {
#[error("{0}")]
/// Some headers were marked as required, but are missing
RequiredError(#[from] RequiredError),
#[error("No host provided for URL, {0}")]
/// Missing host
Host(String),
}
impl From<http_signature_normalization::PrepareVerifyError> for PrepareVerifyError {
@ -342,6 +349,16 @@ impl Config {
Config::default()
}
/// Since manually setting the Host header doesn't work so well in AWC, you can use this method
/// to enable setting the Host header for signing requests without breaking client
/// functionality
pub fn set_host_header(self) -> Self {
Config {
config: self.config,
set_host: true,
}
}
/// Opt out of using the (created) and (expires) fields introduced in draft 11
///
/// Use this for compatibility with mastodon
@ -351,6 +368,7 @@ impl Config {
pub fn dont_use_created_field(self) -> Self {
Config {
config: self.config.dont_use_created_field(),
set_host: self.set_host,
}
}
@ -358,6 +376,7 @@ impl Config {
pub fn set_expiration(self, expires_after: Duration) -> Self {
Config {
config: self.config.set_expiration(expires_after),
set_host: self.set_host,
}
}
@ -365,6 +384,7 @@ impl Config {
pub fn require_header(self, header: &str) -> Self {
Config {
config: self.config.require_header(header),
set_host: self.set_host,
}
}

View file

@ -66,10 +66,29 @@ where
E: From<BlockingError<E>> + From<PrepareSignError> + std::fmt::Debug + Send + 'static,
K: Display,
{
let mut headers = request.headers().clone();
if config.set_host {
let header_string = request
.get_uri()
.host()
.ok_or_else(|| PrepareSignError::Host(request.get_uri().to_string()))?
.to_string();
let header_string = match request.get_uri().port().map(|p| p.as_u16()) {
None | Some(443) | Some(80) => header_string,
Some(port) => format!("{}:{}", header_string, port),
};
headers.insert(
"Host".parse().unwrap(),
header_string
.parse()
.map_err(|_| PrepareSignError::Host(request.get_uri().to_string()))?,
);
}
let unsigned = config.begin_sign(
request.get_method(),
request.get_uri().path_and_query(),
request.headers().clone(),
headers,
)?;
let key_id = key_id.to_string();