You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

569 lines
18KB

  1. use utils::de::fix_toml_dates;
  2. use utils::fs::{get_file_time, is_path_in_directory, read_file};
  3. use reqwest::{blocking::Client, header};
  4. use std::collections::hash_map::DefaultHasher;
  5. use std::fmt;
  6. use std::hash::{Hash, Hasher};
  7. use std::str::FromStr;
  8. use url::Url;
  9. use std::path::PathBuf;
  10. use std::sync::{Arc, Mutex};
  11. use csv::Reader;
  12. use std::collections::HashMap;
  13. use tera::{from_value, to_value, Error, Function as TeraFn, Map, Result, Value};
  /// Error text returned when the `path`/`url` arguments of `load_data` are unusable:
  /// both given, neither given, or (via `optional_arg!`) not readable as the expected type.
  static GET_DATA_ARGUMENT_ERROR_MESSAGE: &str =
      "`load_data`: requires EITHER a `path` or `url` argument";
  /// Where `load_data` reads its content from.
  enum DataSource {
      /// A remote resource identified by a parsed URL.
      Url(Url),
      /// A file on disk.
      Path(PathBuf),
  }
  /// The formats `load_data` knows how to turn into a Tera value.
  #[derive(Debug)]
  enum OutputFormat {
      /// Parsed as TOML.
      Toml,
      /// Parsed as JSON.
      Json,
      /// Parsed as CSV into `headers` and `records`.
      Csv,
      /// Returned as an unparsed string.
      Plain,
  }
  27. impl fmt::Display for OutputFormat {
  28. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  29. fmt::Debug::fmt(self, f)
  30. }
  31. }
  32. impl Hash for OutputFormat {
  33. fn hash<H: Hasher>(&self, state: &mut H) {
  34. self.to_string().hash(state);
  35. }
  36. }
  37. impl FromStr for OutputFormat {
  38. type Err = Error;
  39. fn from_str(output_format: &str) -> Result<Self> {
  40. match output_format {
  41. "toml" => Ok(OutputFormat::Toml),
  42. "csv" => Ok(OutputFormat::Csv),
  43. "json" => Ok(OutputFormat::Json),
  44. "plain" => Ok(OutputFormat::Plain),
  45. format => Err(format!("Unknown output format {}", format).into()),
  46. }
  47. }
  48. }
  49. impl OutputFormat {
  50. fn as_accept_header(&self) -> header::HeaderValue {
  51. header::HeaderValue::from_static(match self {
  52. OutputFormat::Json => "application/json",
  53. OutputFormat::Csv => "text/csv",
  54. OutputFormat::Toml => "application/toml",
  55. OutputFormat::Plain => "text/plain",
  56. })
  57. }
  58. }
  59. impl DataSource {
  60. fn from_args(
  61. path_arg: Option<String>,
  62. url_arg: Option<String>,
  63. content_path: &PathBuf,
  64. ) -> Result<Self> {
  65. if path_arg.is_some() && url_arg.is_some() {
  66. return Err(GET_DATA_ARGUMENT_ERROR_MESSAGE.into());
  67. }
  68. if let Some(path) = path_arg {
  69. let full_path = content_path.join(path);
  70. if !full_path.exists() {
  71. return Err(format!("{} doesn't exist", full_path.display()).into());
  72. }
  73. return Ok(DataSource::Path(full_path));
  74. }
  75. if let Some(url) = url_arg {
  76. return Url::parse(&url)
  77. .map(DataSource::Url)
  78. .map_err(|e| format!("Failed to parse {} as url: {}", url, e).into());
  79. }
  80. Err(GET_DATA_ARGUMENT_ERROR_MESSAGE.into())
  81. }
  82. fn get_cache_key(&self, format: &OutputFormat) -> u64 {
  83. let mut hasher = DefaultHasher::new();
  84. format.hash(&mut hasher);
  85. self.hash(&mut hasher);
  86. hasher.finish()
  87. }
  88. }
  89. impl Hash for DataSource {
  90. fn hash<H: Hasher>(&self, state: &mut H) {
  91. match self {
  92. DataSource::Url(url) => url.hash(state),
  93. DataSource::Path(path) => {
  94. path.hash(state);
  95. get_file_time(&path).expect("get file time").hash(state);
  96. }
  97. };
  98. }
  99. }
  100. fn get_data_source_from_args(
  101. content_path: &PathBuf,
  102. args: &HashMap<String, Value>,
  103. ) -> Result<DataSource> {
  104. let path_arg = optional_arg!(String, args.get("path"), GET_DATA_ARGUMENT_ERROR_MESSAGE);
  105. let url_arg = optional_arg!(String, args.get("url"), GET_DATA_ARGUMENT_ERROR_MESSAGE);
  106. DataSource::from_args(path_arg, url_arg, content_path)
  107. }
  108. fn read_data_file(base_path: &PathBuf, full_path: PathBuf) -> Result<String> {
  109. if !is_path_in_directory(&base_path, &full_path)
  110. .map_err(|e| format!("Failed to read data file {}: {}", full_path.display(), e))?
  111. {
  112. return Err(format!(
  113. "{} is not inside the base site directory {}",
  114. full_path.display(),
  115. base_path.display()
  116. )
  117. .into());
  118. }
  119. read_file(&full_path).map_err(|e| {
  120. format!("`load_data`: error {} loading file {}", full_path.to_str().unwrap(), e).into()
  121. })
  122. }
  123. fn get_output_format_from_args(
  124. args: &HashMap<String, Value>,
  125. data_source: &DataSource,
  126. ) -> Result<OutputFormat> {
  127. let format_arg = optional_arg!(
  128. String,
  129. args.get("format"),
  130. "`load_data`: `format` needs to be an argument with a string value, being one of the supported `load_data` file types (csv, json, toml, plain)"
  131. );
  132. if let Some(format) = format_arg {
  133. if format == "plain" {
  134. return Ok(OutputFormat::Plain);
  135. }
  136. return OutputFormat::from_str(&format);
  137. }
  138. let from_extension = if let DataSource::Path(path) = data_source {
  139. path.extension().map(|extension| extension.to_str().unwrap()).unwrap_or_else(|| "plain")
  140. } else {
  141. "plain"
  142. };
  143. // Always default to Plain if we don't know what it is
  144. OutputFormat::from_str(from_extension).or_else(|_| Ok(OutputFormat::Plain))
  145. }
  /// A Tera function to load data from a file or from a URL
  /// Currently the supported formats are json, toml, csv and plain text
  #[derive(Debug)]
  pub struct LoadData {
      // Directory that `path` arguments are joined onto and that loaded files
      // must stay inside.
      base_path: PathBuf,
      // HTTP client used for `url` sources, behind a mutex.
      client: Arc<Mutex<Client>>,
      // Successfully parsed results, keyed by the hash of source and format.
      result_cache: Arc<Mutex<HashMap<u64, Value>>>,
  }
  154. impl LoadData {
  155. pub fn new(base_path: PathBuf) -> Self {
  156. let client = Arc::new(Mutex::new(Client::builder().build().expect("reqwest client build")));
  157. let result_cache = Arc::new(Mutex::new(HashMap::new()));
  158. Self { base_path, client, result_cache }
  159. }
  160. }
  161. impl TeraFn for LoadData {
  162. fn call(&self, args: &HashMap<String, Value>) -> Result<Value> {
  163. let data_source = get_data_source_from_args(&self.base_path, &args)?;
  164. let file_format = get_output_format_from_args(&args, &data_source)?;
  165. let cache_key = data_source.get_cache_key(&file_format);
  166. let mut cache = self.result_cache.lock().expect("result cache lock");
  167. let response_client = self.client.lock().expect("response client lock");
  168. if let Some(cached_result) = cache.get(&cache_key) {
  169. return Ok(cached_result.clone());
  170. }
  171. let data = match data_source {
  172. DataSource::Path(path) => read_data_file(&self.base_path, path),
  173. DataSource::Url(url) => {
  174. let response = response_client
  175. .get(url.as_str())
  176. .header(header::ACCEPT, file_format.as_accept_header())
  177. .send()
  178. .and_then(|res| res.error_for_status())
  179. .map_err(|e| {
  180. format!(
  181. "Failed to request {}: {}",
  182. url,
  183. e.status().expect("response status")
  184. )
  185. })?;
  186. response
  187. .text()
  188. .map_err(|e| format!("Failed to parse response from {}: {:?}", url, e).into())
  189. }
  190. }?;
  191. let result_value: Result<Value> = match file_format {
  192. OutputFormat::Toml => load_toml(data),
  193. OutputFormat::Csv => load_csv(data),
  194. OutputFormat::Json => load_json(data),
  195. OutputFormat::Plain => to_value(data).map_err(|e| e.into()),
  196. };
  197. if let Ok(data_result) = &result_value {
  198. cache.insert(cache_key, data_result.clone());
  199. }
  200. result_value
  201. }
  202. }
  203. /// Parse a JSON string and convert it to a Tera Value
  204. fn load_json(json_data: String) -> Result<Value> {
  205. let json_content: Value =
  206. serde_json::from_str(json_data.as_str()).map_err(|e| format!("{:?}", e))?;
  207. Ok(json_content)
  208. }
  209. /// Parse a TOML string and convert it to a Tera Value
  210. fn load_toml(toml_data: String) -> Result<Value> {
  211. let toml_content: toml::Value = toml::from_str(&toml_data).map_err(|e| format!("{:?}", e))?;
  212. let toml_value = to_value(toml_content).expect("Got invalid JSON that was valid TOML somehow");
  213. match toml_value {
  214. Value::Object(m) => Ok(fix_toml_dates(m)),
  215. _ => unreachable!("Loaded something other than a TOML object"),
  216. }
  217. }
  218. /// Parse a CSV string and convert it to a Tera Value
  219. ///
  220. /// An example csv file `example.csv` could be:
  221. /// ```csv
  222. /// Number, Title
  223. /// 1,Gutenberg
  224. /// 2,Printing
  225. /// ```
  226. /// The json value output would be:
  227. /// ```json
  228. /// {
  229. /// "headers": ["Number", "Title"],
  230. /// "records": [
  231. /// ["1", "Gutenberg"],
  232. /// ["2", "Printing"]
  233. /// ],
  234. /// }
  235. /// ```
  236. fn load_csv(csv_data: String) -> Result<Value> {
  237. let mut reader = Reader::from_reader(csv_data.as_bytes());
  238. let mut csv_map = Map::new();
  239. {
  240. let hdrs = reader.headers().map_err(|e| {
  241. format!("'load_data': {} - unable to read CSV header line (line 1) for CSV file", e)
  242. })?;
  243. let headers_array = hdrs.iter().map(|v| Value::String(v.to_string())).collect();
  244. csv_map.insert(String::from("headers"), Value::Array(headers_array));
  245. }
  246. {
  247. let records = reader.records();
  248. let mut records_array: Vec<Value> = Vec::new();
  249. for result in records {
  250. let record = match result {
  251. Ok(r) => r,
  252. Err(e) => {
  253. return Err(tera::Error::chain(
  254. String::from("Error encountered when parsing csv records"),
  255. e,
  256. ));
  257. }
  258. };
  259. let mut elements_array: Vec<Value> = Vec::new();
  260. for e in record.into_iter() {
  261. elements_array.push(Value::String(String::from(e)));
  262. }
  263. records_array.push(Value::Array(elements_array));
  264. }
  265. csv_map.insert(String::from("records"), Value::Array(records_array));
  266. }
  267. let csv_value: Value = Value::Object(csv_map);
  268. to_value(csv_value).map_err(|err| err.into())
  269. }
  #[cfg(test)]
  mod tests {
      use super::{DataSource, LoadData, OutputFormat};

      use std::collections::HashMap;
      use std::path::PathBuf;

      use mockito::mock;
      use serde_json::json;
      use tera::{to_value, Function, Value};

      // NOTE: HTTP mock paths below are randomly generated to avoid name
      // collisions. Mocks with the same path can sometimes bleed between tests
      // and cause them to randomly pass/fail. Please make sure to use unique
      // paths when adding or modifying tests that use Mockito.

      /// Absolute path of a fixture inside `../utils/test-files`.
      fn get_test_file(filename: &str) -> PathBuf {
          let test_files = PathBuf::from("../utils/test-files").canonicalize().unwrap();
          test_files.join(filename)
      }

      /// Builds the argument map of a `load_data` call from string key/value pairs.
      fn build_args(pairs: &[(&str, &str)]) -> HashMap<String, Value> {
          pairs.iter().map(|(key, value)| (key.to_string(), to_value(value).unwrap())).collect()
      }

      #[test]
      fn fails_when_missing_file() {
          let static_fn = LoadData::new(PathBuf::from("../utils"));
          let args = build_args(&[("path", "../../../READMEE.md")]);
          let result = static_fn.call(&args);
          assert!(result.is_err());
          assert!(result.unwrap_err().to_string().contains("READMEE.md doesn't exist"));
      }

      #[test]
      fn cant_load_outside_content_dir() {
          let static_fn = LoadData::new(PathBuf::from("../utils"));
          let args = build_args(&[("path", "../../README.md"), ("format", "plain")]);
          let result = static_fn.call(&args);
          assert!(result.is_err());
          assert!(result
              .unwrap_err()
              .to_string()
              .contains("README.md is not inside the base site directory"));
      }

      #[test]
      fn calculates_cache_key_for_path() {
          // We can't test against a fixed value, due to the fact the cache key is built from the absolute path
          let cache_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          let cache_key_2 =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          assert_eq!(cache_key, cache_key_2);
      }

      #[test]
      fn calculates_cache_key_for_url() {
          let _m = mock("GET", "/kr1zdgbm4y")
              .with_header("content-type", "text/plain")
              .with_body("Test")
              .create();

          let url = format!("{}{}", mockito::server_url(), "/kr1zdgbm4y");
          let cache_key = DataSource::Url(url.parse().unwrap()).get_cache_key(&OutputFormat::Plain);
          assert_eq!(cache_key, 425638486551656875);
      }

      #[test]
      fn different_cache_key_per_filename() {
          let toml_cache_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          let json_cache_key =
              DataSource::Path(get_test_file("test.json")).get_cache_key(&OutputFormat::Toml);
          assert_ne!(toml_cache_key, json_cache_key);
      }

      #[test]
      fn different_cache_key_per_format() {
          let toml_cache_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          let json_cache_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Json);
          assert_ne!(toml_cache_key, json_cache_key);
      }

      #[test]
      fn can_load_remote_data() {
          let _m = mock("GET", "/zpydpkjj67")
              .with_header("content-type", "application/json")
              .with_body("{\n\"test\": {\n\"foo\": \"bar\"\n}\n}\n")
              .create();

          let url = format!("{}{}", mockito::server_url(), "/zpydpkjj67");
          let static_fn = LoadData::new(PathBuf::new());
          let args = build_args(&[("url", url.as_str()), ("format", "json")]);
          let result = static_fn.call(&args).unwrap();
          assert_eq!(result.get("test").unwrap().get("foo").unwrap(), &to_value("bar").unwrap());
      }

      #[test]
      fn fails_when_request_404s() {
          let _m = mock("GET", "/aazeow0kog")
              .with_status(404)
              .with_header("content-type", "text/plain")
              .with_body("Not Found")
              .create();

          let url = format!("{}{}", mockito::server_url(), "/aazeow0kog");
          let static_fn = LoadData::new(PathBuf::new());
          let args = build_args(&[("url", url.as_str()), ("format", "json")]);
          let result = static_fn.call(&args);
          assert!(result.is_err());
          assert_eq!(
              result.unwrap_err().to_string(),
              format!("Failed to request {}: 404 Not Found", url)
          );
      }

      #[test]
      fn can_load_toml() {
          let static_fn = LoadData::new(PathBuf::from("../utils/test-files"));
          let args = build_args(&[("path", "test.toml")]);
          let result = static_fn.call(&args).unwrap();

          // TOML does not load in order
          assert_eq!(
              result,
              json!({
                  "category": {
                      "date": "1979-05-27T07:32:00Z",
                      "lt1": "07:32:00",
                      "key": "value"
                  },
              })
          );
      }

      #[test]
      fn unknown_extension_defaults_to_plain() {
          let static_fn = LoadData::new(PathBuf::from("../utils/test-files"));
          let args = build_args(&[("path", "test.css")]);
          let result = static_fn.call(&args).unwrap();

          let expected = if cfg!(windows) { ".hello {}\r\n" } else { ".hello {}\n" };
          assert_eq!(result, expected);
      }

      #[test]
      fn can_override_known_extension_with_format() {
          let static_fn = LoadData::new(PathBuf::from("../utils/test-files"));
          let args = build_args(&[("path", "test.csv"), ("format", "plain")]);
          let result = static_fn.call(&args).unwrap();

          let expected = if cfg!(windows) {
              "Number,Title\r\n1,Gutenberg\r\n2,Printing"
          } else {
              "Number,Title\n1,Gutenberg\n2,Printing"
          };
          assert_eq!(result, expected);
      }

      #[test]
      fn will_use_format_on_unknown_extension() {
          let static_fn = LoadData::new(PathBuf::from("../utils/test-files"));
          let args = build_args(&[("path", "test.css"), ("format", "plain")]);
          let result = static_fn.call(&args).unwrap();

          let expected = if cfg!(windows) { ".hello {}\r\n" } else { ".hello {}\n" };
          assert_eq!(result, expected);
      }

      #[test]
      fn can_load_csv() {
          let static_fn = LoadData::new(PathBuf::from("../utils/test-files"));
          let args = build_args(&[("path", "test.csv")]);
          let result = static_fn.call(&args).unwrap();

          assert_eq!(
              result,
              json!({
                  "headers": ["Number", "Title"],
                  "records": [
                      ["1", "Gutenberg"],
                      ["2", "Printing"]
                  ],
              })
          )
      }

      // Test points to bad csv file with uneven row lengths
      #[test]
      fn bad_csv_should_result_in_error() {
          let static_fn = LoadData::new(PathBuf::from("../utils/test-files"));
          let args = build_args(&[("path", "uneven_rows.csv")]);
          let result = static_fn.call(&args);

          assert!(result.is_err());

          match result.unwrap_err().kind {
              tera::ErrorKind::Msg(msg) => assert_eq!(
                  msg, "Error encountered when parsing csv records",
                  "Error message is wrong. Perhaps wrong error is being returned?"
              ),
              _ => panic!("Error encountered was not expected CSV error"),
          }
      }

      #[test]
      fn can_load_json() {
          let static_fn = LoadData::new(PathBuf::from("../utils/test-files"));
          let args = build_args(&[("path", "test.json")]);
          let result = static_fn.call(&args).unwrap();

          assert_eq!(
              result,
              json!({
                  "key": "value",
                  "array": [1, 2, 3],
                  "subpackage": {
                      "subkey": 5
                  }
              })
          )
      }
  }