diff --git a/src/find/matchers/ls.rs b/src/find/matchers/ls.rs index 2480fbad..bf75f55d 100644 --- a/src/find/matchers/ls.rs +++ b/src/find/matchers/ls.rs @@ -7,6 +7,7 @@ use chrono::DateTime; use std::{ fs::File, io::{stderr, Write}, + sync::Arc, }; use super::{Matcher, MatcherIO, WalkEntry}; @@ -110,11 +111,11 @@ fn format_permissions(file_attributes: u32) -> String { } pub struct Ls { - output_file: Option, + output_file: Option>, } impl Ls { - pub fn new(output_file: Option) -> Self { + pub fn new(output_file: Option>) -> Self { Self { output_file } } @@ -268,7 +269,7 @@ impl Ls { impl Matcher for Ls { fn matches(&self, file_info: &WalkEntry, matcher_io: &mut MatcherIO) -> bool { if let Some(file) = &self.output_file { - self.print(file_info, matcher_io, file, true); + self.print(file_info, matcher_io, file.as_ref(), true); } else { self.print( file_info, diff --git a/src/find/matchers/mod.rs b/src/find/matchers/mod.rs index 05ad37ae..3cdf654d 100644 --- a/src/find/matchers/mod.rs +++ b/src/find/matchers/mod.rs @@ -31,6 +31,18 @@ pub mod time; mod type_matcher; mod user; +use ::regex::Regex; +use chrono::{DateTime, Datelike, NaiveDateTime, Utc}; +use fs::FileSystemMatcher; +use ls::Ls; +use std::collections::HashMap; +use std::fs::{File, Metadata}; +use std::path::Path; +use std::sync::Arc; +use std::time::SystemTime; +use std::{error::Error, str::FromStr}; +use uucore::fs::FileInformation; + use self::access::AccessMatcher; use self::delete::DeleteMatcher; use self::empty::EmptyMatcher; @@ -58,18 +70,7 @@ use self::time::{ }; use self::type_matcher::{TypeMatcher, XtypeMatcher}; use self::user::{NoUserMatcher, UserMatcher}; -use ::regex::Regex; -use chrono::{DateTime, Datelike, NaiveDateTime, Utc}; -use fs::FileSystemMatcher; -use ls::Ls; -use std::{ - error::Error, - fs::{File, Metadata}, - io::Read, - path::Path, - str::FromStr, - time::SystemTime, -}; +use std::io::Read; use super::{Config, Dependencies}; @@ -272,13 +273,45 @@ impl ComparableValue { } } +// Used on file output arguments. +// If yes, use the same file pointer. +struct FileMemoizer { + mem: HashMap>, +} +impl FileMemoizer { + fn new() -> Self { + Self { + mem: HashMap::new(), + } + } + fn get_or_create_file(&mut self, path: &str) -> Result, Box> { + let mut file_info = FileInformation::from_path(path, true); + match file_info { + Ok(info) => { + let file = self + .mem + .entry(info) + .or_insert(Arc::new(File::create(path)?)); + Ok(file.clone()) + } + Err(_) => { + let file = Arc::new(File::create(path)?); + file_info = FileInformation::from_path(path, true); + self.mem.insert(file_info?, file.clone()); + Ok(file) + } + } + } +} + /// Builds a single `AndMatcher` containing the Matcher objects corresponding /// to the passed in predicate arguments. pub fn build_top_level_matcher( args: &[&str], config: &mut Config, ) -> Result, Box> { - let (_, top_level_matcher) = (build_matcher_tree(args, config, 0, false))?; + let mut file_mem = FileMemoizer::new(); + let (_, top_level_matcher) = (build_matcher_tree(args, config, &mut file_mem, 0, false))?; // if the matcher doesn't have any side-effects, then we default to printing if !top_level_matcher.has_side_effects() { @@ -421,13 +454,6 @@ fn parse_str_to_newer_args(input: &str) -> Option<(String, String)> { } } -/// Creates a file if it doesn't exist. -/// If it does exist, it will be overwritten. -fn get_or_create_file(path: &str) -> Result> { - let file = File::create(path)?; - Ok(file) -} - /// The main "translate command-line args into a matcher" function. Will call /// itself recursively if it encounters an opening bracket. A successful return /// consists of a tuple containing the new index into the args array to use (if @@ -435,6 +461,7 @@ fn get_or_create_file(path: &str) -> Result> { fn build_matcher_tree( args: &[&str], config: &mut Config, + file_mem: &mut FileMemoizer, arg_index: usize, mut expecting_bracket: bool, ) -> Result<(usize, Box), Box> { @@ -465,7 +492,7 @@ fn build_matcher_tree( } i += 1; - let file = get_or_create_file(args[i])?; + let file = file_mem.get_or_create_file(args[i])?; Some(Printer::new(PrintDelimiter::Newline, Some(file)).into_box()) } "-fprintf" => { @@ -477,7 +504,7 @@ fn build_matcher_tree( // Args + 1: output file path // Args + 2: format string i += 1; - let file = get_or_create_file(args[i])?; + let file = file_mem.get_or_create_file(args[i])?; i += 1; Some(Printf::new(args[i], Some(file))?.into_box()) } @@ -487,7 +514,7 @@ fn build_matcher_tree( } i += 1; - let file = get_or_create_file(args[i])?; + let file = file_mem.get_or_create_file(args[i])?; Some(Printer::new(PrintDelimiter::Null, Some(file)).into_box()) } "-ls" => Some(Ls::new(None).into_box()), @@ -497,7 +524,7 @@ fn build_matcher_tree( } i += 1; - let file = get_or_create_file(args[i])?; + let file = file_mem.get_or_create_file(args[i])?; Some(Ls::new(Some(file)).into_box()) } "-true" => Some(TrueMatcher.into_box()), @@ -814,7 +841,8 @@ fn build_matcher_tree( None } "(" => { - let (new_arg_index, sub_matcher) = build_matcher_tree(args, config, i + 1, true)?; + let (new_arg_index, sub_matcher) = + build_matcher_tree(args, config, file_mem, i + 1, true)?; i = new_arg_index; Some(sub_matcher) } @@ -1797,30 +1825,4 @@ mod tests { .expect("-version should stop parsing"); assert!(config.version_requested); } - - #[test] - fn get_or_create_file_test() { - use std::fs; - - // remove file if hard link file exist. - // But you can't delete a file that doesn't exist, - // so ignore the error returned here. - let _ = fs::remove_file("test_data/get_or_create_file_test"); - - // test create file - let file = get_or_create_file("test_data/get_or_create_file_test"); - assert!(file.is_ok()); - - let file = get_or_create_file("test_data/get_or_create_file_test"); - assert!(file.is_ok()); - - // test error when file no permission - #[cfg(unix)] - { - let result = get_or_create_file("/etc/shadow"); - assert!(result.is_err()); - } - - let _ = fs::remove_file("test_data/get_or_create_file_test"); - } } diff --git a/src/find/matchers/printer.rs b/src/find/matchers/printer.rs index 5d9e4f8b..dc453892 100644 --- a/src/find/matchers/printer.rs +++ b/src/find/matchers/printer.rs @@ -6,6 +6,7 @@ use std::fs::File; use std::io::{stderr, Write}; +use std::sync::Arc; use super::{Matcher, MatcherIO, WalkEntry}; @@ -26,11 +27,11 @@ impl std::fmt::Display for PrintDelimiter { /// This matcher just prints the name of the file to stdout. pub struct Printer { delimiter: PrintDelimiter, - output_file: Option, + output_file: Option>, } impl Printer { - pub fn new(delimiter: PrintDelimiter, output_file: Option) -> Self { + pub fn new(delimiter: PrintDelimiter, output_file: Option>) -> Self { Self { delimiter, output_file, @@ -71,7 +72,7 @@ impl Printer { impl Matcher for Printer { fn matches(&self, file_info: &WalkEntry, matcher_io: &mut MatcherIO) -> bool { if let Some(file) = &self.output_file { - self.print(file_info, matcher_io, file, true); + self.print(file_info, matcher_io, file.as_ref(), true); } else { self.print( file_info, @@ -127,7 +128,7 @@ mod tests { let dev_full = File::open("/dev/full").unwrap(); let abbbc = get_dir_entry_for("./test_data/simple", "abbbc"); - let matcher = Printer::new(PrintDelimiter::Newline, Some(dev_full)); + let matcher = Printer::new(PrintDelimiter::Newline, Some(Arc::new(dev_full))); let deps = FakeDependencies::new(); assert!(matcher.matches(&abbbc, &mut deps.new_matcher_io())); diff --git a/src/find/matchers/printf.rs b/src/find/matchers/printf.rs index b6722152..09c72fea 100644 --- a/src/find/matchers/printf.rs +++ b/src/find/matchers/printf.rs @@ -7,6 +7,7 @@ use std::error::Error; use std::fs::{self, File}; use std::path::Path; +use std::sync::Arc; use std::time::SystemTime; use std::{borrow::Cow, io::Write}; @@ -572,11 +573,11 @@ fn format_directive<'entry>( /// find's printf syntax. pub struct Printf { format: FormatString, - output_file: Option, + output_file: Option>, } impl Printf { - pub fn new(format: &str, output_file: Option) -> Result> { + pub fn new(format: &str, output_file: Option>) -> Result> { Ok(Self { format: FormatString::parse(format)?, output_file, @@ -624,7 +625,7 @@ impl Printf { impl Matcher for Printf { fn matches(&self, file_info: &WalkEntry, matcher_io: &mut MatcherIO) -> bool { if let Some(file) = &self.output_file { - self.print(file_info, file); + self.print(file_info, file.as_ref()); } else { self.print(file_info, &mut *matcher_io.deps.get_output().borrow_mut()); } diff --git a/tests/find_cmd_tests.rs b/tests/find_cmd_tests.rs index c4c3692e..52f4aca3 100644 --- a/tests/find_cmd_tests.rs +++ b/tests/find_cmd_tests.rs @@ -12,6 +12,7 @@ use assert_cmd::Command; use predicates::prelude::*; use regex::Regex; use serial_test::serial; +use std::collections::HashMap; use std::fs::{self, File}; use std::io::{Read, Write}; use std::{env, io::ErrorKind}; @@ -1122,6 +1123,192 @@ fn find_fprinter() { } } +struct TestCaseData { + search_dir: &'static str, + args: Vec<&'static str>, + needle_substr: Vec<(&'static str, usize)>, +} + +#[test] +#[serial(working_dir)] +fn find_using_same_out_multiple_times() { + let cases = HashMap::from([ + ( + "fprint", + TestCaseData { + search_dir: "test_data/simple", + args: vec![ + "-fprint", + "test_data/find_fprint", + "-fprint", + "test_data/find_fprint", + ], + needle_substr: vec![ + ("test_data/simple\n", 2), + ("test_data/simple/subdir\n", 2), + ("test_data/simple/subdir/ABBBC\n", 2), + ("test_data/simple/abbbc\n", 2), + ], + }, + ), + ( + "fprint0", + TestCaseData { + search_dir: "test_data/simple", + args: vec![ + "-fprint0", + "test_data/find_fprint0", + "-fprint0", + "test_data/find_fprint0", + ], + needle_substr: vec![ + ("test_data/simple\0", 2), + ("test_data/simple/subdir\0", 2), + ("test_data/simple/subdir/ABBBC\0", 2), + ("test_data/simple/abbbc\0", 2), + ], + }, + ), + ( + "fprintf", + TestCaseData { + search_dir: "test_data/simple", + args: vec![ + "-fprintf", + "test_data/find_fprintf", + "%p\n", + "-fprintf", + "test_data/find_fprintf", + "%f\n", + ], + needle_substr: vec![ + ("test_data/simple\nsimple\n", 1), + ("test_data/simple/subdir\nsubdir\n", 1), + ("test_data/simple/subdir/ABBBC\nABBBC\n", 1), + ("test_data/simple/abbbc\nabbbc\n", 1), + ], + }, + ), + ]); + + for (key, test_data) in cases { + Command::cargo_bin("find") + .expect("found binary") + .arg(fix_up_slashes(test_data.search_dir)) + .args(test_data.args) + .assert() + .success() + .stdout(predicate::str::is_empty()) + .stderr(predicate::str::is_empty()); + + // Read the generated content + let mut f = File::open(format!("test_data/find_{key}")).unwrap(); + let mut contents = String::new(); + f.read_to_string(&mut contents).unwrap(); + + // The find output can have different order depending on the platform, so we check by substrs + for (substr, times) in test_data.needle_substr { + let substr = fix_up_slashes(substr); + assert_eq!( + &contents + .as_bytes() + .windows(substr.len()) + .filter(|&w| w == substr.as_bytes()) + .count(), + ×, + "Failes on key '{key}' substr '{substr}'" + ); + } + + let _ = fs::remove_file(format!("test_data/find_{key}")); + } + + let original_file = "test_data/find_fprint_original_links"; + assert!( + File::create(original_file).is_ok(), + "Error creating original file for symlink." + ); + + let symlink_path = "test_data/find_fprint_symlink"; + #[cfg(unix)] + { + let symlink_file = symlink(std::fs::canonicalize(original_file).unwrap(), symlink_path); + assert!( + symlink_file.is_ok(), + "Error creating symlink file. {:?}", + symlink_file + ); + } + #[cfg(windows)] + { + let symlink_file = std::os::windows::fs::symlink_file( + std::fs::canonicalize(original_file).unwrap(), + symlink_path, + ); + assert!( + symlink_file.is_ok(), + "Error creating symlink file. {:?}", + symlink_file + ); + } + + let hardlink_path = "test_data/find_fprint_hardlink"; + let hardlink_file = + std::fs::hard_link(std::fs::canonicalize(original_file).unwrap(), hardlink_path); + assert!( + hardlink_file.is_ok(), + "Error creating hardlink file. {:?}", + hardlink_file + ); + + let test = TestCaseData { + search_dir: "test_data/simple", + args: vec![ + "-fprint", + original_file, + "-fprint", + symlink_path, + "-fprint", + hardlink_path, + ], + needle_substr: vec![ + ("test_data/simple\n", 3), + ("test_data/simple/subdir\n", 3), + ("test_data/simple/subdir/ABBBC\n", 3), + ("test_data/simple/abbbc\n", 3), + ], + }; + Command::cargo_bin("find") + .expect("found binary") + .arg(fix_up_slashes(test.search_dir)) + .args(test.args) + .assert() + .success() + .stdout(predicate::str::is_empty()) + .stderr(predicate::str::is_empty()); + + let mut f = File::open(original_file).unwrap(); + let mut contents = String::new(); + f.read_to_string(&mut contents).unwrap(); + + for (substr, times) in test.needle_substr { + let substr = fix_up_slashes(substr); + assert_eq!( + &contents + .as_bytes() + .windows(substr.len()) + .filter(|&w| w == substr.as_bytes()) + .count(), + ×, + "Error at substr '{substr}'" + ); + } + + let _ = fs::remove_file(original_file); + let _ = fs::remove_file(symlink_path); + let _ = fs::remove_file(hardlink_path); +} + #[test] #[serial(working_dir)] fn find_follow() {