You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

471 lines
15KB

  1. extern crate serde_json;
  2. extern crate toml;
  3. use utils::fs::{get_file_time, is_path_in_directory, read_file};
  4. use reqwest::{header, Client};
  5. use std::collections::hash_map::DefaultHasher;
  6. use std::fmt;
  7. use std::hash::{Hash, Hasher};
  8. use std::str::FromStr;
  9. use url::Url;
  10. use std::path::PathBuf;
  11. use std::sync::{Arc, Mutex};
  12. use csv::Reader;
  13. use std::collections::HashMap;
  14. use tera::{from_value, to_value, Error, GlobalFn, Map, Result, Value};
  /// Error message used when the `path`/`url` arguments of `load_data` are
  /// invalid: the caller must supply one of them, and not both.
  static GET_DATA_ARGUMENT_ERROR_MESSAGE: &'static str =
      "`load_data`: requires EITHER a `path` or `url` argument";
  17. enum DataSource {
  18. Url(Url),
  19. Path(PathBuf),
  20. }
  /// The formats `load_data` knows how to turn into a template value.
  #[derive(Debug)]
  enum OutputFormat {
      /// TOML document.
      Toml,
      /// JSON document.
      Json,
      /// CSV table.
      Csv,
      /// Raw text, passed through as a single string.
      Plain,
  }
  28. impl fmt::Display for OutputFormat {
  29. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
  30. fmt::Debug::fmt(self, f)
  31. }
  32. }
  33. impl Hash for OutputFormat {
  34. fn hash<H: Hasher>(&self, state: &mut H) {
  35. self.to_string().hash(state);
  36. }
  37. }
  38. impl FromStr for OutputFormat {
  39. type Err = Error;
  40. fn from_str(output_format: &str) -> Result<Self> {
  41. return match output_format {
  42. "toml" => Ok(OutputFormat::Toml),
  43. "csv" => Ok(OutputFormat::Csv),
  44. "json" => Ok(OutputFormat::Json),
  45. "plain" => Ok(OutputFormat::Plain),
  46. format => Err(format!("Unknown output format {}", format).into()),
  47. };
  48. }
  49. }
  50. impl OutputFormat {
  51. fn as_accept_header(&self) -> header::HeaderValue {
  52. return header::HeaderValue::from_static(match self {
  53. OutputFormat::Json => "application/json",
  54. OutputFormat::Csv => "text/csv",
  55. OutputFormat::Toml => "application/toml",
  56. OutputFormat::Plain => "text/plain",
  57. });
  58. }
  59. }
  60. impl DataSource {
  61. fn from_args(
  62. path_arg: Option<String>,
  63. url_arg: Option<String>,
  64. content_path: &PathBuf,
  65. ) -> Result<Self> {
  66. if path_arg.is_some() && url_arg.is_some() {
  67. return Err(GET_DATA_ARGUMENT_ERROR_MESSAGE.into());
  68. }
  69. if let Some(path) = path_arg {
  70. let full_path = content_path.join(path);
  71. if !full_path.exists() {
  72. return Err(format!("{} doesn't exist", full_path.display()).into());
  73. }
  74. return Ok(DataSource::Path(full_path));
  75. }
  76. if let Some(url) = url_arg {
  77. return Url::parse(&url)
  78. .map(|parsed_url| DataSource::Url(parsed_url))
  79. .map_err(|e| format!("Failed to parse {} as url: {}", url, e).into());
  80. }
  81. return Err(GET_DATA_ARGUMENT_ERROR_MESSAGE.into());
  82. }
  83. fn get_cache_key(&self, format: &OutputFormat) -> u64 {
  84. let mut hasher = DefaultHasher::new();
  85. format.hash(&mut hasher);
  86. self.hash(&mut hasher);
  87. return hasher.finish();
  88. }
  89. }
  90. impl Hash for DataSource {
  91. fn hash<H: Hasher>(&self, state: &mut H) {
  92. match self {
  93. DataSource::Url(url) => url.hash(state),
  94. DataSource::Path(path) => {
  95. path.hash(state);
  96. get_file_time(&path).expect("get file time").hash(state);
  97. }
  98. };
  99. }
  100. }
  101. fn get_data_source_from_args(
  102. content_path: &PathBuf,
  103. args: &HashMap<String, Value>,
  104. ) -> Result<DataSource> {
  105. let path_arg = optional_arg!(String, args.get("path"), GET_DATA_ARGUMENT_ERROR_MESSAGE);
  106. let url_arg = optional_arg!(String, args.get("url"), GET_DATA_ARGUMENT_ERROR_MESSAGE);
  107. return DataSource::from_args(path_arg, url_arg, content_path);
  108. }
  109. fn read_data_file(base_path: &PathBuf, full_path: PathBuf) -> Result<String> {
  110. if !is_path_in_directory(&base_path, &full_path)
  111. .map_err(|e| format!("Failed to read data file {}: {}", full_path.display(), e))?
  112. {
  113. return Err(format!(
  114. "{} is not inside the base site directory {}",
  115. full_path.display(),
  116. base_path.display()
  117. )
  118. .into());
  119. }
  120. return read_file(&full_path).map_err(|e| {
  121. format!("`load_data`: error {} loading file {}", full_path.to_str().unwrap(), e).into()
  122. });
  123. }
  124. fn get_output_format_from_args(
  125. args: &HashMap<String, Value>,
  126. data_source: &DataSource,
  127. ) -> Result<OutputFormat> {
  128. let format_arg = optional_arg!(
  129. String,
  130. args.get("format"),
  131. "`load_data`: `format` needs to be an argument with a string value, being one of the supported `load_data` file types (csv, json, toml)"
  132. );
  133. if let Some(format) = format_arg {
  134. return OutputFormat::from_str(&format);
  135. }
  136. let from_extension = if let DataSource::Path(path) = data_source {
  137. let extension_result: Result<&str> =
  138. path.extension().map(|extension| extension.to_str().unwrap()).ok_or(
  139. format!("Could not determine format for {} from extension", path.display()).into(),
  140. );
  141. extension_result?
  142. } else {
  143. "plain"
  144. };
  145. return OutputFormat::from_str(from_extension);
  146. }
  147. /// A global function to load data from a file or from a URL
  148. /// Currently the supported formats are json, toml, csv and plain text
  149. pub fn make_load_data(content_path: PathBuf, base_path: PathBuf) -> GlobalFn {
  150. let mut headers = header::HeaderMap::new();
  151. headers.insert(header::USER_AGENT, "zola".parse().unwrap());
  152. let client = Arc::new(Mutex::new(Client::builder().build().expect("reqwest client build")));
  153. let result_cache: Arc<Mutex<HashMap<u64, Value>>> = Arc::new(Mutex::new(HashMap::new()));
  154. Box::new(move |args| -> Result<Value> {
  155. let data_source = get_data_source_from_args(&content_path, &args)?;
  156. let file_format = get_output_format_from_args(&args, &data_source)?;
  157. let cache_key = data_source.get_cache_key(&file_format);
  158. let mut cache = result_cache.lock().expect("result cache lock");
  159. let response_client = client.lock().expect("response client lock");
  160. if let Some(cached_result) = cache.get(&cache_key) {
  161. return Ok(cached_result.clone());
  162. }
  163. let data = match data_source {
  164. DataSource::Path(path) => read_data_file(&base_path, path),
  165. DataSource::Url(url) => {
  166. let mut response = response_client
  167. .get(url.as_str())
  168. .header(header::ACCEPT, file_format.as_accept_header())
  169. .send()
  170. .and_then(|res| res.error_for_status())
  171. .map_err(|e| {
  172. format!(
  173. "Failed to request {}: {}",
  174. url,
  175. e.status().expect("response status")
  176. )
  177. })?;
  178. response
  179. .text()
  180. .map_err(|e| format!("Failed to parse response from {}: {:?}", url, e).into())
  181. }
  182. }?;
  183. let result_value: Result<Value> = match file_format {
  184. OutputFormat::Toml => load_toml(data),
  185. OutputFormat::Csv => load_csv(data),
  186. OutputFormat::Json => load_json(data),
  187. OutputFormat::Plain => to_value(data).map_err(|e| e.into()),
  188. };
  189. if let Ok(data_result) = &result_value {
  190. cache.insert(cache_key, data_result.clone());
  191. }
  192. result_value
  193. })
  194. }
  195. /// Parse a JSON string and convert it to a Tera Value
  196. fn load_json(json_data: String) -> Result<Value> {
  197. let json_content: Value =
  198. serde_json::from_str(json_data.as_str()).map_err(|e| format!("{:?}", e))?;
  199. return Ok(json_content);
  200. }
  201. /// Parse a TOML string and convert it to a Tera Value
  202. fn load_toml(toml_data: String) -> Result<Value> {
  203. let toml_content: toml::Value = toml::from_str(&toml_data).map_err(|e| format!("{:?}", e))?;
  204. to_value(toml_content).map_err(|e| e.into())
  205. }
  206. /// Parse a CSV string and convert it to a Tera Value
  207. ///
  208. /// An example csv file `example.csv` could be:
  209. /// ```csv
  210. /// Number, Title
  211. /// 1,Gutenberg
  212. /// 2,Printing
  213. /// ```
  214. /// The json value output would be:
  215. /// ```json
  216. /// {
  217. /// "headers": ["Number", "Title"],
  218. /// "records": [
  219. /// ["1", "Gutenberg"],
  220. /// ["2", "Printing"]
  221. /// ],
  222. /// }
  223. /// ```
  224. fn load_csv(csv_data: String) -> Result<Value> {
  225. let mut reader = Reader::from_reader(csv_data.as_bytes());
  226. let mut csv_map = Map::new();
  227. {
  228. let hdrs = reader.headers().map_err(|e| {
  229. format!("'load_data': {} - unable to read CSV header line (line 1) for CSV file", e)
  230. })?;
  231. let headers_array = hdrs.iter().map(|v| Value::String(v.to_string())).collect();
  232. csv_map.insert(String::from("headers"), Value::Array(headers_array));
  233. }
  234. {
  235. let records = reader.records();
  236. let mut records_array: Vec<Value> = Vec::new();
  237. for result in records {
  238. let record = result.unwrap();
  239. let mut elements_array: Vec<Value> = Vec::new();
  240. for e in record.into_iter() {
  241. elements_array.push(Value::String(String::from(e)));
  242. }
  243. records_array.push(Value::Array(elements_array));
  244. }
  245. csv_map.insert(String::from("records"), Value::Array(records_array));
  246. }
  247. let csv_value: Value = Value::Object(csv_map);
  248. to_value(csv_value).map_err(|err| err.into())
  249. }
  #[cfg(test)]
  mod tests {
      use super::{make_load_data, DataSource, OutputFormat};

      use std::collections::HashMap;
      use std::path::PathBuf;

      use tera::to_value;

      /// Absolute path of a fixture inside `../utils/test-files`.
      fn get_test_file(filename: &str) -> PathBuf {
          PathBuf::from("../utils/test-files").canonicalize().unwrap().join(filename)
      }

      #[test]
      fn fails_when_missing_file() {
          let load_data =
              make_load_data(PathBuf::from("../utils/test-files"), PathBuf::from("../utils"));
          let mut args = HashMap::new();
          args.insert("path".to_string(), to_value("../../../READMEE.md").unwrap());

          let outcome = load_data(args);

          assert!(outcome.is_err());
          assert!(outcome.unwrap_err().description().contains("READMEE.md doesn't exist"));
      }

      #[test]
      fn cant_load_outside_content_dir() {
          let load_data =
              make_load_data(PathBuf::from("../utils/test-files"), PathBuf::from("../utils"));
          let mut args = HashMap::new();
          args.insert("path".to_string(), to_value("../../../README.md").unwrap());
          args.insert("format".to_string(), to_value("plain").unwrap());

          let outcome = load_data(args);

          assert!(outcome.is_err());
          assert!(
              outcome
                  .unwrap_err()
                  .description()
                  .contains("README.md is not inside the base site directory")
          );
      }

      #[test]
      fn calculates_cache_key_for_path() {
          // We can't test against a fixed value, due to the fact the cache key is built from the absolute path
          let first_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          let second_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          assert_eq!(first_key, second_key);
      }

      #[test]
      fn calculates_cache_key_for_url() {
          let url_source =
              DataSource::Url("https://api.github.com/repos/getzola/zola".parse().unwrap());
          let cache_key = url_source.get_cache_key(&OutputFormat::Plain);
          assert_eq!(cache_key, 8916756616423791754);
      }

      #[test]
      fn different_cache_key_per_filename() {
          let toml_file_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          let json_file_key =
              DataSource::Path(get_test_file("test.json")).get_cache_key(&OutputFormat::Toml);
          assert_ne!(toml_file_key, json_file_key);
      }

      #[test]
      fn different_cache_key_per_format() {
          let as_toml_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Toml);
          let as_json_key =
              DataSource::Path(get_test_file("test.toml")).get_cache_key(&OutputFormat::Json);
          assert_ne!(as_toml_key, as_json_key);
      }

      #[test]
      fn can_load_remote_data() {
          let load_data = make_load_data(PathBuf::new(), PathBuf::new());
          let mut args = HashMap::new();
          args.insert("url".to_string(), to_value("https://httpbin.org/json").unwrap());
          args.insert("format".to_string(), to_value("json").unwrap());

          let loaded = load_data(args).unwrap();

          assert_eq!(
              loaded.get("slideshow").unwrap().get("title").unwrap(),
              &to_value("Sample Slide Show").unwrap()
          );
      }

      #[test]
      fn fails_when_request_404s() {
          let load_data = make_load_data(PathBuf::new(), PathBuf::new());
          let mut args = HashMap::new();
          args.insert("url".to_string(), to_value("https://httpbin.org/status/404/").unwrap());
          args.insert("format".to_string(), to_value("json").unwrap());

          let outcome = load_data(args);

          assert!(outcome.is_err());
          assert_eq!(
              outcome.unwrap_err().description(),
              "Failed to request https://httpbin.org/status/404/: 404 Not Found"
          );
      }

      #[test]
      fn can_load_toml() {
          let load_data = make_load_data(
              PathBuf::from("../utils/test-files"),
              PathBuf::from("../utils/test-files"),
          );
          let mut args = HashMap::new();
          args.insert("path".to_string(), to_value("test.toml").unwrap());

          let loaded = load_data(args).unwrap();

          //TOML does not load in order, and also dates are not returned as strings, but
          //rather as another object with a key and value
          assert_eq!(
              loaded,
              json!({
                  "category": {
                      "date": {
                          "$__toml_private_datetime": "1979-05-27T07:32:00Z"
                      },
                      "key": "value"
                  },
              })
          );
      }

      #[test]
      fn can_load_csv() {
          let load_data = make_load_data(
              PathBuf::from("../utils/test-files"),
              PathBuf::from("../utils/test-files"),
          );
          let mut args = HashMap::new();
          args.insert("path".to_string(), to_value("test.csv").unwrap());

          let loaded = load_data(args).unwrap();

          assert_eq!(
              loaded,
              json!({
                  "headers": ["Number", "Title"],
                  "records": [
                      ["1", "Gutenberg"],
                      ["2", "Printing"]
                  ],
              })
          )
      }

      #[test]
      fn can_load_json() {
          let load_data = make_load_data(
              PathBuf::from("../utils/test-files"),
              PathBuf::from("../utils/test-files"),
          );
          let mut args = HashMap::new();
          args.insert("path".to_string(), to_value("test.json").unwrap());

          let loaded = load_data(args).unwrap();

          assert_eq!(
              loaded,
              json!({
                  "key": "value",
                  "array": [1, 2, 3],
                  "subpackage": {
                      "subkey": 5
                  }
              })
          )
      }
  }