Add an ApiToken model, and an endpoint to get one

This commit is contained in:
Baptiste Gelez 2018-10-22 14:30:04 +01:00
parent f2190adfc2
commit 2394ff424b
11 changed files with 148 additions and 12 deletions

View file

@ -0,0 +1,2 @@
-- This file should undo anything in `up.sql`
DROP TABLE api_tokens;

View file

@ -0,0 +1,9 @@
-- Your SQL goes here
CREATE TABLE api_tokens (
id SERIAL PRIMARY KEY,
creation_date TIMESTAMP NOT NULL DEFAULT now(),
value TEXT NOT NULL,
scopes TEXT NOT NULL,
app_id INTEGER NOT NULL REFERENCES apps(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE
)

View file

@ -0,0 +1,2 @@
-- This file should undo anything in `up.sql`
DROP TABLE api_tokens;

View file

@ -0,0 +1,9 @@
-- Your SQL goes here
CREATE TABLE api_tokens (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
creation_date DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
value TEXT NOT NULL,
scopes TEXT NOT NULL,
app_id INTEGER NOT NULL REFERENCES apps(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE
)

View file

@ -1,5 +1,6 @@
use gettextrs::gettext;
use heck::CamelCase;
use openssl::rand::rand_bytes;
use pulldown_cmark::{Event, Parser, Options, Tag, html};
use rocket::{
http::uri::Uri,
@ -7,6 +8,13 @@ use rocket::{
};
use std::collections::HashSet;
/// Generates an hexadecimal representation of 32 bytes of random data
pub fn random_hex() -> String {
let mut bytes = [0; 32];
rand_bytes(&mut bytes).expect("Error while generating client id");
bytes.into_iter().fold(String::new(), |res, byte| format!("{}{:x}", res, byte))
}
/// Remove non alphanumeric characters and CamelCase a string
pub fn make_actor_id(name: String) -> String {
name.as_str()

View file

@ -0,0 +1,40 @@
use chrono::NaiveDateTime;
use diesel::{self, ExpressionMethods, QueryDsl, RunQueryDsl};
use schema::api_tokens;
#[derive(Clone, Queryable)]
pub struct ApiToken {
pub id: i32,
pub creation_date: NaiveDateTime,
pub value: String,
/// Scopes, separated by +
/// Global scopes are read and write
/// and both can be limited to an endpoint by affixing them with :ENDPOINT
///
/// Examples :
///
/// read
/// read+write
/// read:posts
/// read:posts+write:posts
pub scopes: String,
pub app_id: i32,
pub user_id: i32,
}
#[derive(Insertable)]
#[table_name = "api_tokens"]
pub struct NewApiToken {
pub value: String,
pub scopes: String,
pub app_id: i32,
pub user_id: i32,
}
impl ApiToken {
get!(api_tokens);
insert!(api_tokens, NewApiToken);
find_by!(api_tokens, find_by_value, value as String);
}

View file

@ -1,9 +1,9 @@
use canapi::{Error, Provider};
use chrono::NaiveDateTime;
use diesel::{self, RunQueryDsl, QueryDsl, ExpressionMethods};
use openssl::rand::rand_bytes;
use plume_api::apps::AppEndpoint;
use plume_common::utils::random_hex;
use Connection;
use schema::apps;
@ -14,7 +14,7 @@ pub struct App {
pub client_id: String,
pub client_secret: String,
pub redirect_uri: Option<String>,
pub website: Option<String>,
pub website: Option<String>,
pub creation_date: NaiveDateTime,
}
@ -25,7 +25,7 @@ pub struct NewApp {
pub client_id: String,
pub client_secret: String,
pub redirect_uri: Option<String>,
pub website: Option<String>,
pub website: Option<String>,
}
impl Provider<Connection> for App {
@ -40,13 +40,9 @@ impl Provider<Connection> for App {
}
fn create(conn: &Connection, data: AppEndpoint) -> Result<AppEndpoint, Error> {
let mut id = [0; 32];
rand_bytes(&mut id).expect("Error while generating client id");
let client_id = id.into_iter().fold(String::new(), |res, byte| format!("{}{:x}", res, byte));
let mut secret = [0; 32];
rand_bytes(&mut secret).expect("Error while generating client secret");
let client_secret = secret.into_iter().fold(String::new(), |res, byte| format!("{}{:x}", res, byte));
let client_id = random_hex();
let client_secret = random_hex();
let app = App::insert(conn, NewApp {
name: data.name.expect("App::create: name is required"),
client_id: client_id,
@ -68,7 +64,7 @@ impl Provider<Connection> for App {
fn update(conn: &Connection, id: i32, new_data: AppEndpoint) -> Result<AppEndpoint, Error> {
unimplemented!()
}
fn delete(conn: &Connection, id: i32) {
unimplemented!()
}
@ -77,4 +73,5 @@ impl Provider<Connection> for App {
impl App {
get!(apps);
insert!(apps, NewApp);
}
find_by!(apps, find_by_client_id, client_id as String);
}

View file

@ -214,6 +214,7 @@ pub fn ap_url(url: String) -> String {
}
pub mod admin;
pub mod api_tokens;
pub mod apps;
pub mod blog_authors;
pub mod blogs;

View file

@ -1,3 +1,14 @@
table! {
api_tokens (id) {
id -> Int4,
creation_date -> Timestamp,
value -> Text,
scopes -> Text,
app_id -> Int4,
user_id -> Int4,
}
}
table! {
apps (id) {
id -> Int4,
@ -184,6 +195,8 @@ table! {
}
}
joinable!(api_tokens -> apps (app_id));
joinable!(api_tokens -> users (user_id));
joinable!(blog_authors -> blogs (blog_id));
joinable!(blog_authors -> users (author_id));
joinable!(blogs -> instances (instance_id));
@ -204,6 +217,7 @@ joinable!(tags -> posts (post_id));
joinable!(users -> instances (instance_id));
allow_tables_to_appear_in_same_query!(
api_tokens,
apps,
blog_authors,
blogs,

View file

@ -1,2 +1,54 @@
use rocket_contrib::Json;
use serde_json;
use plume_common::utils::random_hex;
use plume_models::{
apps::App,
api_tokens::*,
db_conn::DbConn,
users::User,
};
#[derive(FromForm)]
struct OAuthRequest {
client_id: String,
client_secret: String,
password: String,
username: String,
scopes: String,
}
#[get("/oauth2?<query>")]
fn oauth(query: OAuthRequest, conn: DbConn) -> Json<serde_json::Value> {
let app = App::find_by_client_id(&*conn, query.client_id).expect("OAuth request from unknown client");
if app.client_secret == query.client_secret {
if let Some(user) = User::find_local(&*conn, query.username) {
if user.auth(query.password) {
let token = ApiToken::insert(&*conn, NewApiToken {
app_id: app.id,
user_id: user.id,
value: random_hex(),
scopes: query.scopes,
});
Json(json!({
"token": token.value
}))
} else {
Json(json!({
"error": "Wrong password"
}))
}
} else {
Json(json!({
"error": "Unknown user"
}))
}
} else {
Json(json!({
"error": "Invalid client_secret"
}))
}
}
pub mod apps;
pub mod posts;

View file

@ -156,6 +156,8 @@ fn main() {
routes::errors::csrf_violation
])
.mount("/api/v1", routes![
api::oauth,
api::apps::create,
api::posts::get,