// hacker-challenge-sgx-general/src/persistence.rs
//
// Source-viewer metadata: 376 lines, 12 KiB, Rust.

#![allow(dead_code)]
use ed25519_dalek::SigningKey;
use ed25519_dalek::KEYPAIR_LENGTH;
use std::net::AddrParseError;
use std::net::Ipv4Addr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom};
use tokio::sync::Mutex;
/// Size in bytes of one serialized node record:
/// 4 bytes of IPv4 octets + 64 bytes of ed25519 keypair + 8-byte little-endian timestamp.
const DATA_SIZE: usize = 4 + 64 + 8;
/// Errors produced when assembling a `Node` from loosely-typed parts.
///
/// Derives `Debug`/`PartialEq` so that results carrying it can be
/// `unwrap`ped, `expect`ed and compared in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Error {
    /// The IP address text could not be parsed as an IPv4 address.
    CorruptedIP,
}
impl From<AddrParseError> for Error {
fn from(_: AddrParseError) -> Self {
Error::CorruptedIP
}
}
/// A node record; `to_bytes`/`from_bytes` define its fixed-size binary layout.
#[derive(Clone)]
struct Node {
    /// IPv4 address of the node.
    ip: Ipv4Addr,
    /// The node's ed25519 signing key, serialized via `to_keypair_bytes`.
    keypair: SigningKey,
    /// When the node joined; serialization keeps whole seconds only.
    joined_at: SystemTime,
}
impl TryFrom<(&str, SigningKey, SystemTime)> for Node {
type Error = Error;
fn try_from(value: (&str, SigningKey, SystemTime)) -> Result<Self, Self::Error> {
Ok(Self {
ip: value.0.parse()?,
keypair: value.1,
joined_at: value.2,
})
}
}
impl Node {
fn ip_as_string(&self) -> String {
self.ip.to_string()
}
fn signing_key(&self) -> SigningKey {
self.keypair.clone()
}
fn to_bytes(self) -> [u8; DATA_SIZE] {
let mut result = [0; DATA_SIZE];
result[0..4].copy_from_slice(&self.ip.octets());
result[4..68].copy_from_slice(&self.keypair.to_keypair_bytes());
result[68..DATA_SIZE].copy_from_slice(
&self
.joined_at
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
.to_le_bytes(),
);
result
}
fn from_bytes(bytes: [u8; DATA_SIZE]) -> Self {
let ip: [u8; 4] = bytes[0..4].try_into().unwrap();
let ip: Ipv4Addr = ip.into();
let keypair: [u8; KEYPAIR_LENGTH] = bytes[4..68].try_into().unwrap();
let keypair: SigningKey = SigningKey::from_keypair_bytes(&keypair).unwrap();
let joined_at: [u8; 8] = bytes[68..DATA_SIZE].try_into().unwrap();
let joined_at: u64 = u64::from_le_bytes(joined_at);
let joined_at = SystemTime::UNIX_EPOCH + Duration::from_secs(joined_at);
Self {
ip,
keypair,
joined_at,
}
}
}
struct FileManager {
file: Mutex<File>,
}
impl FileManager {
async fn init(path: &str) -> std::io::Result<Self> {
let file = File::options().read(true).append(true).open(path).await?;
Ok(Self {
file: Mutex::new(file),
})
}
async fn append_node(&self, node: Node) -> std::io::Result<()> {
let mut file = self.file.lock().await;
file.seek(SeekFrom::End(0)).await?;
file.write_all(&node.to_bytes()).await?;
file.flush().await?;
Ok(())
}
async fn get_node_by_id(&self, id: u64) -> std::io::Result<Node> {
let mut file = self.file.lock().await;
file.seek(SeekFrom::Start(
id.wrapping_mul(DATA_SIZE.try_into().unwrap_or(0)),
))
.await?;
let mut node_bytes = [0; DATA_SIZE];
file.read_exact(&mut node_bytes).await?;
Ok(Node::from_bytes(node_bytes))
}
/// Returns 20 nodes from the disk.
/// Specify offset (the number of nodes to skip).
async fn get_page_of_20(&self, offset: u64) -> std::io::Result<Vec<Node>> {
let mut file = self.file.lock().await;
file.seek(SeekFrom::Start(
offset
.wrapping_mul(DATA_SIZE.try_into().unwrap_or(0)),
))
.await?;
let mut nodes = Vec::new();
let mut count = 0;
loop {
let mut node_bytes = [0; DATA_SIZE];
if let Err(_) = file.read_exact(&mut node_bytes).await {
break;
};
nodes.push(Node::from_bytes(node_bytes));
count += 1;
if count == 20 {
break;
}
}
Ok(nodes)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;
    use rand::Rng;
    use std::io::Result;
    use tokio::fs::remove_file;
    use tokio::io::AsyncWriteExt;

    const TEST_FILE_PREFIX: &str = ".tmp/test_";

    /// Path of the scratch file used by the test named `function`.
    fn get_test_file_name(function: &str) -> String {
        TEST_FILE_PREFIX.to_string() + function
    }

    /// Creates (or truncates) an empty scratch file for `function` and opens a
    /// `FileManager` on it.
    async fn setup_test_file(function: &str) -> Result<FileManager> {
        tokio::fs::create_dir_all(".tmp").await?;
        let path = get_test_file_name(function);
        let mut file = File::create(&path).await?;
        file.flush().await?;
        drop(file);
        FileManager::init(&path).await
    }

    /// Whole seconds since the Unix epoch: the precision `joined_at` keeps on disk.
    fn epoch_secs(time: SystemTime) -> u64 {
        time.duration_since(UNIX_EPOCH).unwrap().as_secs()
    }

    /// Asserts that `actual` matches `expected` in every persisted field.
    fn assert_same_node(expected: &Node, actual: &Node) {
        assert_eq!(expected.ip_as_string(), actual.ip_as_string());
        assert_eq!(
            expected.keypair.to_keypair_bytes(),
            actual.keypair.to_keypair_bytes()
        );
        assert_eq!(epoch_secs(expected.joined_at), epoch_secs(actual.joined_at));
    }

    /// Asserts that both slices have the same length and pairwise-equal nodes.
    fn assert_same_nodes(expected: &[Node], actual: &[Node]) {
        assert_eq!(expected.len(), actual.len());
        for (e, a) in expected.iter().zip(actual) {
            assert_same_node(e, a);
        }
    }

    /// A node with a fresh keypair, a random IPv4 address and the current time.
    fn get_random_node() -> Node {
        let mut rng = rand::thread_rng();
        Node {
            ip: Ipv4Addr::new(rng.gen(), rng.gen(), rng.gen(), rng.gen()),
            keypair: SigningKey::generate(&mut OsRng),
            joined_at: SystemTime::now(),
        }
    }

    /// Appends one random node per id in `ids` and returns clones of those
    /// for which `keep(id)` is true, in append order.
    async fn append_random_nodes(
        manager: &FileManager,
        ids: std::ops::Range<u64>,
        keep: impl Fn(u64) -> bool,
    ) -> Result<Vec<Node>> {
        let mut kept = Vec::new();
        for id in ids {
            let node = get_random_node();
            if keep(id) {
                kept.push(node.clone());
            }
            manager.append_node(node).await?;
        }
        Ok(kept)
    }

    #[test]
    fn node_round_trip() {
        let original_node = Node {
            ip: "192.168.1.1".parse().unwrap(),
            keypair: SigningKey::generate(&mut OsRng),
            joined_at: SystemTime::now(),
        };
        let restored_node = Node::from_bytes(original_node.clone().to_bytes());
        assert_same_node(&original_node, &restored_node);
    }

    #[tokio::test]
    async fn setup_file_manager() -> Result<()> {
        let function_name = "setup_file_manager";
        let manager = setup_test_file(function_name).await?;
        drop(manager);
        remove_file(get_test_file_name(function_name)).await
    }

    #[tokio::test]
    async fn append_and_retrieve() -> Result<()> {
        let function_name = "append_and_retrieve";
        let manager = setup_test_file(function_name).await?;
        let node = get_random_node();
        manager.append_node(node.clone()).await?;
        assert_same_node(&node, &manager.get_node_by_id(0).await?);
        remove_file(get_test_file_name(function_name)).await?;
        Ok(())
    }

    #[tokio::test]
    async fn append_and_retrieve_multiple() -> Result<()> {
        let function_name = "append_and_retrieve_multiple";
        let manager = setup_test_file(function_name).await?;
        let node1 = get_random_node();
        let node2 = get_random_node();
        manager.append_node(node1.clone()).await?;
        manager.append_node(node2.clone()).await?;
        let retrieved_node1 = manager.get_node_by_id(0).await?;
        // Appending between reads must not disturb records already written.
        let node3 = get_random_node();
        manager.append_node(node3.clone()).await?;
        let retrieved_node2 = manager.get_node_by_id(1).await?;
        assert_same_node(&node1, &retrieved_node1);
        assert_same_node(&node2, &retrieved_node2);
        assert_same_node(&node3, &manager.get_node_by_id(2).await?);
        remove_file(get_test_file_name(function_name)).await?;
        Ok(())
    }

    #[tokio::test]
    async fn append_20_and_retrieve_1_loop() -> Result<()> {
        let function_name = "append_20_and_retrieve_1_loop";
        let manager = setup_test_file(function_name).await?;
        // Every 10th node is remembered so it can be looked up by id later.
        let mut sampled = append_random_nodes(&manager, 0..100, |id| id % 10 == 0).await?;
        for (node, id) in sampled.iter().zip((0u64..).step_by(10)).take(4) {
            assert_same_node(node, &manager.get_node_by_id(id).await?);
        }
        sampled.extend(append_random_nodes(&manager, 100..500, |id| id % 10 == 0).await?);
        assert_eq!(sampled.len(), 50);
        for (node, id) in sampled.iter().zip((0u64..).step_by(10)) {
            assert_same_node(node, &manager.get_node_by_id(id).await?);
        }
        remove_file(get_test_file_name(function_name)).await?;
        Ok(())
    }

    #[tokio::test]
    async fn get_page_of_20_nodes() -> Result<()> {
        let function_name = "get_page_of_20_nodes";
        let manager = setup_test_file(function_name).await?;
        let expected = append_random_nodes(&manager, 0..100, |id| (23..43).contains(&id)).await?;
        assert_eq!(expected.len(), 20);
        let page = manager.get_page_of_20(23).await?;
        assert_same_nodes(&expected, &page);
        remove_file(get_test_file_name(function_name)).await?;
        Ok(())
    }

    #[tokio::test]
    async fn get_last_page() -> Result<()> {
        let function_name = "get_last_page";
        let manager = setup_test_file(function_name).await?;
        let expected = append_random_nodes(&manager, 0..97, |id| id >= 90).await?;
        // Only 7 records remain after offset 90, so the page is short.
        assert_eq!(expected.len(), 7);
        let page = manager.get_page_of_20(90).await?;
        assert_same_nodes(&expected, &page);
        remove_file(get_test_file_name(function_name)).await?;
        Ok(())
    }

    #[tokio::test]
    async fn get_page_past_end_is_empty() -> Result<()> {
        let function_name = "get_page_past_end_is_empty";
        let manager = setup_test_file(function_name).await?;
        append_random_nodes(&manager, 0..5, |_| false).await?;
        assert!(manager.get_page_of_20(5).await?.is_empty());
        assert!(manager.get_page_of_20(100).await?.is_empty());
        remove_file(get_test_file_name(function_name)).await?;
        Ok(())
    }

    #[tokio::test]
    async fn get_node_past_end_fails() -> Result<()> {
        let function_name = "get_node_past_end_fails";
        let manager = setup_test_file(function_name).await?;
        append_random_nodes(&manager, 0..3, |_| false).await?;
        let err = match manager.get_node_by_id(3).await {
            Ok(_) => panic!("id 3 is past the end of a 3-record file"),
            Err(e) => e,
        };
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        remove_file(get_test_file_name(function_name)).await?;
        Ok(())
    }
}