diff --git a/src/auth.rs b/src/auth.rs index 9dbc225..7e3aef6 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -38,9 +38,14 @@ pub struct Authenticated { impl Authenticated { pub fn can_access_schema(&self, schema: &str) -> bool { - if self.schemas.is_empty() { + if self.role == "admin" { return true; } + + if self.schemas.is_empty() { + return false; + } + self.schemas.iter().any(|s| s == schema) } } @@ -193,13 +198,22 @@ pub async fn register( let mut conn = get_connection(); - // Determine role: first registered user becomes admin, others are regular users. - let count_row = diesel::sql_query("SELECT COUNT(*)::INT as count FROM \"Auth\".users") - .get_result::(&mut conn) - .unwrap_or(UserCount { count: 0 }); - - let role = if count_row.count == 0 { - "admin".to_string() + // Determine role. By default the first registered user becomes admin + // (for easy bootstrap), but this behavior can be disabled by setting + // AUTH_BOOTSTRAP_MODE to any value other than "first-user-admin". + let bootstrap_mode = std::env::var("AUTH_BOOTSTRAP_MODE") + .unwrap_or_else(|_| "first-user-admin".to_string()); + + let role = if bootstrap_mode == "first-user-admin" { + let count_row = diesel::sql_query("SELECT COUNT(*)::INT as count FROM \"Auth\".users") + .get_result::(&mut conn) + .unwrap_or(UserCount { count: 0 }); + + if count_row.count == 0 { + "admin".to_string() + } else { + "user".to_string() + } } else { "user".to_string() }; diff --git a/src/createdatabase.rs b/src/createdatabase.rs index cedd43f..9d7eb72 100644 --- a/src/createdatabase.rs +++ b/src/createdatabase.rs @@ -1,9 +1,13 @@ use crate::dbconnect; use diesel::prelude::*; use diesel::sql_query; +use crate::validation; //create a logical database (schema) in Postgres and register an API key pub fn create_database(database_name: &str) { let mut conn = dbconnect::internalqueryconn(); + // database_name is expected to be validated at the HTTP layer; this is + // a final safety net. + assert!(validation::is_valid_identifier(database_name)); let mut query = String::from("CREATE SCHEMA IF NOT EXISTS "); query.push_str(database_name); query.push_str(";"); @@ -16,6 +20,7 @@ pub fn create_database(database_name: &str) { pub fn create_databaseweb(database: &str) -> String { let mut conn = dbconnect::internalqueryconn(); let dbname = String::from(database); + assert!(validation::is_valid_identifier(&dbname)); let mut query = String::from("CREATE SCHEMA IF NOT EXISTS "); query.push_str(&dbname); query.push_str(";"); diff --git a/src/dbconnect.rs b/src/dbconnect.rs index 33166fe..1192397 100644 --- a/src/dbconnect.rs +++ b/src/dbconnect.rs @@ -4,9 +4,6 @@ use csv::ReaderBuilder; use diesel::pg::PgConnection; use diesel::prelude::*; -// For Postgres/Diesel we maintain a single physical database and use -// the `database` argument only as a logical schema/table prefix inside SQL. -// All connections are created from a DATABASE_URL environment variable. pub fn database_connection(_database: &str) -> PooledConn { internalqueryconn() } @@ -22,7 +19,6 @@ pub fn database_connection_no_db() -> PooledConn { internalqueryconn() } fn grabfromfile() -> LinkDataBase { - //igneroe header let mut reader = ReaderBuilder::new() .has_headers(false) .from_path("tmp/dbconnection.txt") @@ -34,7 +30,6 @@ fn grabfromfile() -> LinkDataBase { dbport: String::new(), }; for result in reader.records() { - //ignore header let record = result.unwrap(); println!("{:?}", record); diff --git a/src/delete.rs b/src/delete.rs index 664ad03..ed909b3 100644 --- a/src/delete.rs +++ b/src/delete.rs @@ -7,19 +7,30 @@ pub fn deleterecord( table: &str, id: Vec<(String, String)>, ) -> std::result::Result { - //grab second string from tuple + // Collect and validate numeric IDs make sure positive + let mut numeric_ids: Vec = Vec::new(); + for (_, raw) in id.iter() { + let trimmed = raw.trim().trim_matches('"'); + match trimmed.parse::() { + Ok(v) if v > 0 => numeric_ids.push(v), + _ => return Err("invalid record id".to_string()), + } + } + + if numeric_ids.is_empty() { + return Err("no record ids provided".to_string()); + } + let mut stmt = String::from("DELETE FROM "); stmt.push_str(database); stmt.push_str("."); stmt.push_str(table); - stmt.push_str(" WHERE "); - stmt.push_str("INTERNAL_PRIMARY_KEY"); - stmt.push_str(" in( "); - for i in 0..id.len() { - stmt.push_str(&id[i].1); - if i != id.len() - 1 { + stmt.push_str(" WHERE INTERNAL_PRIMARY_KEY IN ("); + for (idx, v) in numeric_ids.iter().enumerate() { + if idx > 0 { stmt.push_str(", "); } + stmt.push_str(&v.to_string()); } stmt.push_str(")"); println!("{}", stmt); @@ -27,7 +38,7 @@ pub fn deleterecord( } pub fn droptable(database: &str, table: &str) -> std::result::Result { - //grab second string from tuple + // This is only called from admin-only handlers; the database and table let mut stmt = String::from("DROP TABLE "); stmt.push_str(database); stmt.push_str("."); diff --git a/src/initconnect.rs b/src/initconnect.rs index 0acd2d9..cde1025 100644 --- a/src/initconnect.rs +++ b/src/initconnect.rs @@ -1,8 +1,3 @@ -//initialize conenction parameters like username, password, host, port -//use crate::LinkDataBase; -//use std::fs::File; -//use std::io::Write; - pub fn getpagehtml() -> String { //get page html to type username, password, host, port. let mut html = String::new(); diff --git a/src/insertrecords.rs b/src/insertrecords.rs index 84b014c..08df146 100644 --- a/src/insertrecords.rs +++ b/src/insertrecords.rs @@ -71,7 +71,7 @@ impl TableDef { } } stmt.push_str(") VALUES ("); - for data in date.iter() { + for (record_idx, data) in date.iter().enumerate() { for i in 0..data.len() { let mut valuedata = data[i].1.replace("\"", ""); // Escape single quotes for SQL literal @@ -87,12 +87,11 @@ impl TableDef { stmt.push_str(", "); } } - stmt.push_str("), ("); + if record_idx != date.len() - 1 { + stmt.push_str("), ("); + } } - // remove the trailing `, (` - stmt.pop(); - stmt.pop(); - stmt.pop(); + stmt.push_str(")"); stmt } } diff --git a/src/main.rs b/src/main.rs index c8cdf0d..98b64ee 100755 --- a/src/main.rs +++ b/src/main.rs @@ -21,33 +21,20 @@ pub mod relationships; pub mod tablecreate; pub mod update; pub mod auth; +pub mod validation; // Global database connection type, backed by Diesel Postgres pub type PooledConn = PgConnection; -//use rusoto_s3::*; -//use mysql::prelude::*; -//use crate::createrecord::generateform::CreateRelation; -//use actix_identity::{CookieIdentityPolicy, IdentityService}; -//use futures_util::TryStreamExt as _; -//use uuid::Uuid; -//use actix_multipart::Multipart; -//test + #[actix_web::main] async fn main() { let mut args = std::env::args().nth(1).unwrap(); args.push_str(":8080"); - //let pword=std::env::args().nth(2).unwrap(); - //let secretkey = cookie:: - let redisconnection = String::from("127.0.0.1:6379"); + //let redisconnection = String::from("127.0.0.1:6379"); let server = HttpServer::new(move || { - //App::new() - // .wrap(SessionMiddleware::new( - // RedisActorSessionStore::new(&redisconnection), - // secretkey.clone(), - // )) - App::new() + App::new() .app_data(web::Data::new(auth::AppState::from_env())) //session cookie //.app_data(TempFileConfig::default().directory("./tmp")) @@ -65,11 +52,7 @@ async fn main() { web::post().to(auth::update_user_schemas), ) .route("/health", web::get().to(health)) - //.route("/main", web::get().to(index)) - //.route("/auth", web::post().to(auth)) .route("/getkey/{database}&apikey={apikey}", web::get().to(getkey)) - //j.route("/method", web::post().to(method)) - //.route("/createtable", web::post().to(createtable)) .route( "/createtable/{database}&table={table}&gps={gps}&apikey={apikey}", web::post().to(createtableweb), @@ -78,12 +61,10 @@ async fn main() { "/droptable/{database}&table={table}&apikey={apikey}", web::post().to(droptableweb), ) - //.route("/createdatabase", web::post().to(createnewdb)) .route( "/createdatabase/{database}&apikey={apikey}", web::post().to(createnewdbweb), ) - //.route("/query", web::post().to(query)) .route( "/query/{database}&table={table}&select={select}&where={where}&expand={expand}&apikey={api}", web::get().to(querytojson), @@ -104,12 +85,7 @@ async fn main() { "/queryall/{database}&table={table}&depth={depth}&apikey={api}", web::get().to(queryall), ) - //.service( - // web::resource("/create") - // .route(web::get().to(getcreate)) - // .route(web::post().to(postcreate)), - //) - .route( + .route( "/insert/{database}&table={table}&apikey={api}", web::post().to(dbinsert), ) @@ -125,13 +101,7 @@ async fn main() { "/updaterecord/{database}&table={table}&apikey={api}", web::post().to(dbupdaterecord), ) - //.route("/create/saveform", web::post().to(saveform)) - //.service( - // web::resource("/upload") - // .route(web::get().to(getupload)) - // .route(web::post().to(postupload)), - //) - .route( + .route( "/relationship/{database}&apikey={api}", web::post().to(createrelationshipweb), ) @@ -143,15 +113,7 @@ async fn main() { "/deleterecord/{database}&table={table}&apikey={api}", web::post().to(deleterecord), ) - //.service( - // web::resource("/createrelation") - // .route(web::get().to(getcreaterelation)) - // .route(web::post().to(postcreaterelationdefined)), - //) - - // .route("/insert", web::post().to(method)) - // .route("/create", web::post().to(method)) - }); + }); println!("Starting server at {}", args); server .bind(args) @@ -163,7 +125,7 @@ async fn main() { async fn postinitializeconnect(form: web::Form) -> impl Responder { // Legacy endpoint: API-key based initialization is deprecated. // Kept only to avoid breaking old clients; always returns an error. - let _ = form; // suppress unused warning + let _ = form; HttpResponse::Gone() .content_type("application/json; charset=utf-8") .body("{\"error\":\"deprecated_endpoint_use_jwt_auth\"}") @@ -218,47 +180,6 @@ async fn health() -> impl Responder { } } } -//async fn getcreaterelation() -> impl Responder { -// let html = createrecord::generateform::getcreaterelationshipdefined(); -// HttpResponse::Ok().body(html) -//} -//async fn postcreaterelationdefined(form: web::Form) -> impl Responder { -// let database = form.database.clone(); -// -// let _ = createrelationship::commitrelationshipdefined( -// &database, -// &form.table1, -// &form.column1, -// &form.table2, -// &form.column2, -// &form.ondelete, -// &form.onupdate, -// ); -// HttpResponse::Ok() -// .content_type("text/html; charset=utf-8") -// .body(include_str!("pages/methodsuccess.html")) -//} -//async fn index() -> impl Responder { -// HttpResponse::Ok() -// .content_type("text/html; charset=utf-8") -// .body(include_str!("page.html")) -//} -//async fn getupload() -> impl Responder { -// let html = createrecord::generateform::fileinsert(); -// HttpResponse::Ok().body(html) -//} -//async fn postupload(MultipartForm(form): MultipartForm) -> impl Responder { -// let table = &form.table.clone(); -// let database = &form.database.clone(); -// -// let file = createrecord::generateform::file_upload(form); -// -// let _ = pushdata::createtablestruct::read_csv2(&file, table, database); -// -// HttpResponse::Ok() -// .content_type("text/html; charset=utf-8") -// .body(include_str!("pages/methodsuccess.html")) -//} async fn createrelationshipweb( auth: auth::Authenticated, info: web::Path<(String, String)>, @@ -350,6 +271,12 @@ async fn deleterecord( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + let mut conn = dbconnect::internalqueryconn(); let body = body.into_inner(); let mut data = Vec::new(); @@ -357,8 +284,15 @@ async fn deleterecord( data.push((key.to_string(), value.to_string())); } let table = &info.1; - let statement = delete::deleterecord(&database, &table, data); - let _ = delete::exec_statement(&mut conn, &statement.unwrap()); + let statement = match delete::deleterecord(&database, &table, data) { + Ok(s) => s, + Err(_) => { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid record id(s)"); + } + }; + let _ = delete::exec_statement(&mut conn, &statement); HttpResponse::Ok() .content_type("text/json; charset=utf-8") @@ -427,13 +361,7 @@ async fn createtableweb( let _ = tablecreate::exec_statement(&mut conn, &stmt); } - //let _ = tablecreate::create_table_web( - // &mut conn, - // &database, - // &table, - // &parsed_json.0, - // &parsed_json.1, - //); + HttpResponse::Ok() .content_type("text/json; charset=utf-8") @@ -456,9 +384,14 @@ async fn droptableweb( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + let mut conn = dbconnect::internalqueryconn(); let body = body.into_inner(); - //let mut data=Vec::new(); let mut backup = false; for (_, value) in body.as_object().unwrap().iter() { // data.push((key.to_string(),value.to_string())); @@ -491,6 +424,12 @@ async fn retrieveattachment( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + println!("{:?}", &info.1); let table = &info.1; println!("{:?}", &info.2); @@ -499,10 +438,31 @@ async fn retrieveattachment( let mut conn = dbconnect::internalqueryconn(); - let stmt = querytable::retrieveattachmentstmt(table, database, id); - let result = querytable::exec_map(&mut conn, &stmt.unwrap()); - let result = result.unwrap(); - let encoded = BASE64.encode(&result[0].as_bytes()); + let stmt = match querytable::retrieveattachmentstmt(table, database, id) { + Ok(s) => s, + Err(_) => { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid attachment id"); + } + }; + + let result = match querytable::exec_map(&mut conn, &stmt) { + Ok(r) => r, + Err(_) => { + return HttpResponse::InternalServerError() + .content_type("text/json; charset=utf-8") + .body("Failed to retrieve attachment"); + } + }; + + if result.is_empty() { + return HttpResponse::NotFound() + .content_type("text/json; charset=utf-8") + .body("Attachment not found"); + } + + let encoded = BASE64.encode(result[0].as_bytes()); let json = serde_json::json!( { @@ -514,7 +474,6 @@ async fn retrieveattachment( .content_type("text/json; charset=utf-8") .body(json.to_string()) } -//grab attachment to put in s3 bucket async fn dbinsertattachment( auth: auth::Authenticated, @@ -527,7 +486,11 @@ async fn dbinsertattachment( .content_type("text/json; charset=utf-8") .body("Forbidden for this database"); } - //decode json + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } let body = body.into_inner(); let mut data = Vec::new(); for (key, value) in body.as_object().unwrap().iter() { @@ -542,11 +505,26 @@ async fn dbinsertattachment( let mut attachment = data[1].1.clone(); attachment = attachment.replace("\"", ""); println!("{:?}", attachment); - let encoded = BASE64.decode(attachment.as_bytes()).unwrap(); + + let encoded = match BASE64.decode(attachment.as_bytes()) { + Ok(b) => b, + Err(_) => { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid attachment encoding"); + } + }; println!("{:?}", encoded); - let insertstmt = insertrecords::insert_attachment(&info.0, &info.1, &filename, encoded); - let _ = insertrecords::exec_insert(insertstmt.unwrap()); + let insertstmt = match insertrecords::insert_attachment(&info.0, &info.1, &filename, encoded) { + Ok(s) => s, + Err(_) => { + return HttpResponse::InternalServerError() + .content_type("text/json; charset=utf-8") + .body("Failed to build insert statement"); + } + }; + let _ = insertrecords::exec_insert(insertstmt); //upload to s3 bucket using rust-s3 //let bucket = "testbucket"; @@ -574,6 +552,12 @@ async fn dbinsert( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + let body = body.into_inner(); let mut storagevec: Vec> = Vec::new(); for record in body.iter() { @@ -661,14 +645,18 @@ async fn dbupdaterecord( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + let mut conn = dbconnect::internalqueryconn(); let body = body.into_inner(); let mut storagevec: Vec> = Vec::new(); for record in body.iter() { let mut data = Vec::new(); for (key, value) in record.as_object().unwrap().iter() { - // Store all values as strings (numbers, bools, etc. will be - // stringified) and let the SQL builder quote/cast appropriately. data.push((key.to_string(), value.to_string())); } storagevec.push(data); @@ -702,6 +690,11 @@ async fn createnewdbweb( // Legacy admin API key in the path is ignored in favor of JWT-based admin checks. let database_name = &info.0; + if !validation::is_valid_identifier(database_name) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database name"); + } let key = createdatabase::create_databaseweb(database_name); let encoded = BASE64.encode(key.as_bytes()); let response = serde_json::json!(encoded); @@ -808,6 +801,12 @@ async fn queryall( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + let mut connection = dbconnect::internalqueryconn(); let table = &info.1; let depth = &info.2; @@ -838,6 +837,12 @@ async fn queryrelationship( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database name"); + } + let mut connection = dbconnect::internalqueryconn(); let relationship = &info.1; @@ -895,6 +900,12 @@ async fn querytojson( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&info.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + let mut connection = dbconnect::internalqueryconn(); let tablename = &info.1; @@ -943,6 +954,12 @@ async fn querytableschema( .body("Forbidden for this database"); } + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(&body.1) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database or table name"); + } + let mut connection = dbconnect::internalqueryconn(); let tablename = &body.1; @@ -984,6 +1001,11 @@ async fn querydatabase( let mut connection = dbconnect::internalqueryconn(); let database = &body.0; + if !validation::is_valid_identifier(database) { + return HttpResponse::BadRequest() + .content_type("text/json; charset=utf-8") + .body("Invalid database name"); + } //expand will be true or false let expand = &body.1; //turn into bool @@ -1249,7 +1271,12 @@ mod tests { let valid = newrecord.compare_fields(&body); assert_eq!(valid, true); let insert = newrecord.insert(&body, &table, &database); - assert_eq!(insert, String::from("INSERT INTO unit_tests.testinsertupdatedelete (col1, col2) VALUES (50, 'Test Addition'), (50, 'Test Addition')")); + assert_eq!( + insert, + String::from( + "INSERT INTO unit_tests.testinsertupdatedelete (col1, col2) VALUES ('50', 'Test Addition'), ('50', 'Test Addition')", + ) + ); } #[test] fn test_update_record() { @@ -1264,7 +1291,12 @@ mod tests { let update = update::updaterecord(database, table, datastore); //assert_eq!(update.unwrap(), String::from("Success")); - assert_eq!(update[0], String::from("UPDATE unit_tests.testinsertupdatedelete SET col1= \"50\", col2= \"Changed\" WHERE INTERNAL_PRIMARY_KEY=1")); + assert_eq!( + update[0], + String::from( + "UPDATE unit_tests.testinsertupdatedelete SET col1 = '50', col2 = 'Changed' WHERE INTERNAL_PRIMARY_KEY = 1", + ) + ); } #[test] fn test_delete_record() { @@ -1274,7 +1306,12 @@ mod tests { data.push(("1".to_string(), "1".to_string())); data.push(("2".to_string(), "2".to_string())); let statement = delete::deleterecord(database, table, data); - assert_eq!(statement.unwrap(), String::from("DELETE FROM unit_tests.testinsertupdatedelete WHERE INTERNAL_PRIMARY_KEY in( 1, 2)")); + assert_eq!( + statement.unwrap(), + String::from( + "DELETE FROM unit_tests.testinsertupdatedelete WHERE INTERNAL_PRIMARY_KEY IN (1, 2)", + ) + ); } #[test] fn test_drop_table() { diff --git a/src/pushdata/gettablecol.rs b/src/pushdata/gettablecol.rs index 05adf75..8a7ebd9 100644 --- a/src/pushdata/gettablecol.rs +++ b/src/pushdata/gettablecol.rs @@ -16,8 +16,6 @@ pub fn get_table_col( //testcsv' AND TABLE_NAME='"); querystring.push_str(table_name.to_string().as_str()); querystring.push_str("'"); - // In Postgres, unquoted identifiers are stored lowercased in information_schema, - // so filter on the lowercased names of our internal/system columns. querystring.push_str(" and COLUMN_NAME != 'internal_primary_key'"); querystring.push_str(" and COLUMN_NAME != 'gps_id'"); querystring.push_str(" and COLUMN_NAME != 'x_coord'"); @@ -55,13 +53,9 @@ pub fn createinsertstatement( for j in 0..data[i].columns.len() { println!("New Column"); for k in 0..data[i].columns[j].len() { - //println!("Data below"); println!("{:?}", data[i].columns[j][k]); - //println!("Data above"); let datarecord = &data[i].columns[j][k]; - //insert into mysql data from data variable into columns in columnname variable - //let insertstatement =gettablecol::createinsertstatement(&mut conn, &tablename); - //println!("{}", insertstatement); + insertstatement.push_str("'"); insertstatement.push_str(&datarecord); insertstatement.push_str("'"); diff --git a/src/querytable.rs b/src/querytable.rs index c707a5b..1fcd19c 100644 --- a/src/querytable.rs +++ b/src/querytable.rs @@ -1,4 +1,5 @@ use crate::PooledConn; +use crate::validation; use diesel::prelude::*; use diesel::sql_query; use diesel::sql_types::Text; @@ -31,6 +32,39 @@ pub fn query_tables( } } + +fn normalize_simple_where(whereclause: &str) -> Option { + let trimmed = whereclause.trim(); + + if trimmed.is_empty() { + return None; + } + + if trimmed == "1=1" { + return Some("1=1".to_string()); + } + + let mut parts = trimmed.splitn(2, '='); + let col = match parts.next() { + Some(c) => c.trim(), + None => return Some("1=0".to_string()), + }; + let raw_val = match parts.next() { + Some(v) => v.trim(), + None => return Some("1=0".to_string()), + }; + + if !validation::is_valid_identifier(col) { + return Some("1=0".to_string()); + } + + // Strip optional surrounding single quotes from the value. + let val = raw_val.trim_matches('\''); + let escaped = validation::escape_sql_literal(val); + + Some(format!("{} = '{}'", col, escaped)) +} + pub fn exec_map( conn: &mut PooledConn, query: &str, @@ -188,13 +222,19 @@ pub fn retrieveattachmentstmt( database: &str, id: &str, ) -> std::result::Result> { + // ID must be a positive integer; reject anything else. + let parsed_id: i64 = id.trim().parse()?; + if parsed_id <= 0 { + return Err("invalid attachment id".into()); + } + let mut query = String::from("SELECT Attachment FROM "); query.push_str(database); query.push_str("."); query.push_str(table); query.push_str("_GPS"); - query.push_str(" WHERE INTERNAL_PRIMARY_KEY= "); - query.push_str(id); + query.push_str(" WHERE INTERNAL_PRIMARY_KEY = "); + query.push_str(&parsed_id.to_string()); Ok(query) } @@ -219,9 +259,9 @@ fn query_table( query.push_str(database); query.push_str("."); query.push_str(table); - if !whereclause.is_empty() { + if let Some(safe_where) = normalize_simple_where(whereclause) { query.push_str(" WHERE "); - query.push_str(whereclause); + query.push_str(&safe_where); } #[derive(QueryableByName)] diff --git a/src/relationships.rs b/src/relationships.rs index 60862cb..248457c 100644 --- a/src/relationships.rs +++ b/src/relationships.rs @@ -1,4 +1,5 @@ use crate::PooledConn; +use crate::validation; use diesel::prelude::*; use diesel::sql_query; use serde::{Deserialize, Serialize}; @@ -30,7 +31,8 @@ impl RelationshipBuilder{ relationship } pub fn check_relationship_name(&self, conn: &mut PooledConn)->bool{ - let stmt = format!("SELECT relationship FROM Relationships.relationships WHERE relationship='{}'", self.relationship_name); + let safe_name = validation::escape_sql_literal(&self.relationship_name); + let stmt = format!("SELECT relationship FROM Relationships.relationships WHERE relationship='{}'", safe_name); #[derive(QueryableByName)] struct RelRow { @@ -55,7 +57,25 @@ impl RelationshipBuilder{ } pub fn create_relationship_stmt(relationship: &RelationshipBuilder) -> String{ - let stmt = format!("INSERT INTO Relationships.relationships (targeted_database, parent_table, child_table, where_clause, relationship) VALUES ('{}', '{}', '{}', '{}', '{}')", relationship.database, relationship.parent_table, relationship.child_table, relationship.where_clause, relationship.relationship_name); + let db = &relationship.database; + let parent = &relationship.parent_table; + let child = &relationship.child_table; + let where_clause = &relationship.where_clause; + let name = &relationship.relationship_name; + + // These values ultimately build a SQL statement; escape them as string + // literals so a malicious relationship name or where_clause cannot break + // the INSERT. + let db_esc = validation::escape_sql_literal(db); + let parent_esc = validation::escape_sql_literal(parent); + let child_esc = validation::escape_sql_literal(child); + let where_esc = validation::escape_sql_literal(where_clause); + let name_esc = validation::escape_sql_literal(name); + + let stmt = format!( + "INSERT INTO Relationships.relationships (targeted_database, parent_table, child_table, where_clause, relationship) VALUES ('{}', '{}', '{}', '{}', '{}')", + db_esc, parent_esc, child_esc, where_esc, name_esc, + ); stmt } pub fn execute_relationship_stmt(stmt: &str, conn: &mut PooledConn){ diff --git a/src/tablecreate.rs b/src/tablecreate.rs index 86cd8f8..519917b 100644 --- a/src/tablecreate.rs +++ b/src/tablecreate.rs @@ -1,6 +1,7 @@ use diesel::prelude::*; use diesel::sql_query; use crate::PooledConn; +use crate::validation; //read from csv file to create table in postgres with given column names pub fn exec_statement(conn: &mut PooledConn, statement: &str) { @@ -42,6 +43,10 @@ pub fn create_table_web( column_names: &Vec<(String, String)>, column_types: &Vec<(String, String)>, ) -> String { + if !validation::is_valid_identifier(database) || !validation::is_valid_identifier(table_name) + { + return "Invalid schema or table name".to_string(); + } let mut query = String::from("CREATE TABLE "); query.push_str(database); query.push_str("."); diff --git a/src/validation.rs b/src/validation.rs new file mode 100644 index 0000000..9a34e6b --- /dev/null +++ b/src/validation.rs @@ -0,0 +1,27 @@ +pub fn is_valid_identifier(name: &str) -> bool { + let mut chars = name.chars(); + match chars.next() { + Some(c) if is_ident_start(c) => {} + _ => return false, + } + + for c in chars { + if !is_ident_char(c) { + return false; + } + } + + true +} + +fn is_ident_start(c: char) -> bool { + c == '_' || c.is_ascii_alphabetic() +} + +fn is_ident_char(c: char) -> bool { + c == '_' || c.is_ascii_alphanumeric() +} + +pub fn escape_sql_literal(value: &str) -> String { + value.replace('\'', "''") +}