commit 4ed19f9198a2d4e1518328b9462a58c8dd8a3525 Author: Maurice Date: Thu Aug 15 22:09:53 2024 +0200 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..c87f2e8 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,380 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cpufeatures" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dh-tree-example" +version = "0.1.0" +dependencies = [ + "hex", + "hkdf", + "num-bigint", + "rand", + "sha2", + "x25519-dalek", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + +[[package]] +name = "serde" +version = "1.0.207" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5665e14a49a4ea1b91029ba7d3bca9f299e1f7cfa194388ccc20f14743e784f2" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.207" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6aea2634c86b0e8ef2cfdc0c340baede54ec27b1e46febd7f80dffb2aa44a00e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fceb41e3d546d0bd83421d3409b1460cc7444cd389341a4c880fe7a042cb3d7" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core", + "serde", + "zeroize", +] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..3cb5f2c --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "dh-tree-example" +version = "0.1.0" +edition = "2021" + +[dependencies] +hex = "0.4.3" +hkdf = "0.12.4" +num-bigint = "0.4.6" +rand = "0.8.5" +sha2 = "0.10.8" +x25519-dalek = { version = "2", features = ["static_secrets"]} diff --git a/src/dh_node.rs b/src/dh_node.rs new file mode 100644 index 0000000..962c604 --- /dev/null +++ b/src/dh_node.rs @@ -0,0 +1,128 @@ +use std::borrow::Borrow; + +use hkdf::Hkdf; +use num_bigint::BigInt; +use sha2::Sha256; +use x25519_dalek::{PublicKey, StaticSecret}; + +use crate::tree_error::RatchetTreeError; + +#[derive(Clone)] +pub struct DhNode { + pub pub_key: PublicKey, + pub secret_key: Option, + pub user_id: Option, +} + +impl DhNode { + pub fn new(user_id: Option) -> Self { + let secret_key = StaticSecret::random_from_rng(&mut rand::thread_rng()); + let pub_key = PublicKey::from(&secret_key); + + Self { + pub_key, + secret_key: Some(secret_key), + user_id, + } + } + + pub fn import(user_id: Option, secret_key: Option<&str>, public_key: &str) -> Self { + let secret_key = secret_key + .map(hex::decode) + .map(Result::unwrap) + .map(TryInto::try_into) + .map(Result::unwrap) + .map(|b: [u8; 32]| StaticSecret::from(b)); + + let pub_key = TryInto::<[u8; 32]>::try_into(hex::decode(public_key).unwrap()) + .map(|b| PublicKey::from(b)) + .unwrap(); + + Self { + pub_key, + secret_key, + user_id, + } + } +} + +impl DhNode { + /// Derive parent node from this node and its sibling + pub fn derive( + &self, + sibling: &DhNode, + chain_key: Option<&[u8; 32]>, + ) -> Result { + let key = self.derive_key(sibling, chain_key.map(|k| &k[..]))?; + let secret_key = StaticSecret::from(key); + let pub_key = PublicKey::from(&secret_key); + + Ok(Self { + pub_key, + secret_key: Some(secret_key), + user_id: None, + }) + } + + /// secret_key = hkdf(salt: chain_key, info: sum_big_endian(self.pub_key, other.pub_key), ikm: shared_secret(self.priv_key, other.pub_key)) + pub fn derive_key( + &self, + other: &DhNode, + salt: Option<&[u8]>, + ) -> Result<[u8; 32], RatchetTreeError> { + // Determine which party has a private key + let (secret_key, public_key) = if let Some(secret_key) = self.secret_key.borrow() { + (secret_key, &other.pub_key) + } else if let Some(secret_key) = other.secret_key.borrow() { + (secret_key, &self.pub_key) + } else { + return Err(RatchetTreeError::MissingPrivateKey); + }; + + let shared_secret = secret_key.diffie_hellman(public_key); + let hkdf = Hkdf::::new(salt, shared_secret.as_bytes()); + + let pub_key_sum = BigInt::from_signed_bytes_be(self.pub_key.as_bytes()) + + BigInt::from_signed_bytes_be(other.pub_key.as_bytes()); + + let mut result = [0u8; 32]; + hkdf.expand(&pub_key_sum.to_signed_bytes_be(), &mut result) + .expect("Failed to expand bytes"); + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn can_derive_same_key() { + let pair1 = DhNode::new(None); + let pair2 = DhNode::new(None); + + let pair12 = pair1.derive(&pair2, None).unwrap(); + let pair21 = pair2.derive(&pair1, None).unwrap(); + + assert_eq!( + pair12.secret_key.unwrap().to_bytes(), + pair21.secret_key.unwrap().to_bytes() + ); + } + + #[test] + pub fn derivation_with_chain_key_yields_other_result() { + let pair1 = DhNode::new(None); + let pair2 = DhNode::new(None); + let chain_key = &[1u8; 32]; + + let pair12 = pair1.derive(&pair2, None).unwrap(); + let pair12b = pair2.derive(&pair1, Some(chain_key)).unwrap(); + + assert_ne!( + pair12.secret_key.unwrap().to_bytes(), + pair12b.secret_key.unwrap().to_bytes() + ); + } +} diff --git a/src/dh_tree.rs b/src/dh_tree.rs new file mode 100644 index 0000000..fd2ed1a --- /dev/null +++ b/src/dh_tree.rs @@ -0,0 +1,326 @@ +use std::{cell::RefCell, collections::VecDeque}; + +use hkdf::Hkdf; +use rand::RngCore; +use sha2::Sha256; +use x25519_dalek::{PublicKey, StaticSecret}; + +use crate::{ + dh_node::DhNode, + full_binary_tree::{FullBinaryTree, TreeIteratorMode}, + tree_error::RatchetTreeError, +}; + +#[derive(Clone)] +pub struct RatchetKeyTree { + pub tree: FullBinaryTree, + pub prev_tree_secret: Option<[u8; 32]>, + pub ratchet_key: RefCell>, +} + +pub struct RatchetTreeKeyUpdate { + pub user_id: Option, + pub index: usize, + pub prev_index: Option, + pub new_keys: VecDeque<[u8; 32]>, +} + +// TODO: scan for nodes to remove +// TODO: key update replace other node, set other to remove +impl RatchetKeyTree { + pub fn new() -> Self { + Self { + tree: FullBinaryTree::new(), + prev_tree_secret: None, + ratchet_key: RefCell::new(None), + } + } + + pub fn create_key_update(&self, index: usize) -> (RatchetTreeKeyUpdate, StaticSecret) { + let mut simulation = self.clone(); + let (old_location, new_location) = simulation.tree.find_best_update_location(index); + let user_id = simulation.tree.nodes[index].value.user_id; + + let new_node: &mut crate::tree_node::TreeNode = + &mut simulation.tree.nodes[new_location]; + + let new_data = DhNode::new(user_id); + let new_key = new_data.secret_key.clone().unwrap(); + new_node.value = new_data; + + simulation + .update(new_location) + .expect("Failed to create update on simulation"); + + let new_keys: VecDeque<[u8; 32]> = simulation + .tree + .iter(TreeIteratorMode::Parent(new_location)) + .map(|n| n.value.pub_key.as_bytes().clone()) + .collect(); + + ( + RatchetTreeKeyUpdate { + new_keys, + user_id, + index: new_location, + prev_index: old_location, + }, + new_key, + ) + } + + pub fn apply_key_update( + &mut self, + update: RatchetTreeKeyUpdate, + ) -> Result<(), RatchetTreeError> { + let current_item = self + .tree + .nodes + .get(update.index) + .ok_or(RatchetTreeError::NodeNotFound(update.index))?; + + if !current_item.is_expired && current_item.value.user_id != update.user_id { + return Err(RatchetTreeError::NotAllowedToReplaceNode); + } + + if let Some(prev_index) = update.prev_index { + // Try to cleanup if cleanup is possible (you are neighbour of expired node and last or second last, + // or last two nodes are both expired) + if self.tree.suitable_for_clean_up(prev_index) { + if let Some(index) = self.tree.remove_last_expired_leaf() { + if !index == update.index { + return Err(RatchetTreeError::NotAllowedToDeleteNode); + } + }; + } + + let prev_node = &mut self.tree.nodes[prev_index]; + prev_node.is_expired = true; + prev_node.value.user_id = None; + prev_node.value.secret_key = None; + } + + let mut next_index = Some(update.index); + let mut keys = update.new_keys; + + while let Some(index) = next_index { + let node = self + .tree + .nodes + .get_mut(index) + .ok_or(RatchetTreeError::NodeNotFound(index))?; + + let new_key = keys + .pop_front() + .ok_or(RatchetTreeError::NotEnoughKeysToUpdate)?; + + node.value.pub_key = PublicKey::from(new_key); + node.value.secret_key = None; + node.is_expired = false; + + node.value.user_id = if index == update.index { + update.user_id + } else { + None + }; + + next_index = self.tree.parent_index(index); + } + + Ok(()) + } + + pub fn add(&mut self, new_node: DhNode) -> Result<(), RatchetTreeError> { + let index = self.tree.add(new_node); + self.update(index)?; + Ok(()) + } + + fn update(&mut self, index: usize) -> Result<(), RatchetTreeError> { + let prev_secret = self.tree.nodes[0] + .value + .secret_key + .clone() + .map(|s| s.to_bytes()) + .expect("No secret key present"); + + let mut self_idx = index; + + while let Some(parent_idx) = self.tree.parent_index(self_idx) { + let me = self + .tree + .nodes + .get(self_idx) + .ok_or(RatchetTreeError::NodeNotFound(self_idx))?; + + let sibling_idx = self + .tree + .sibling_index(self_idx) + .ok_or(RatchetTreeError::SiblingNodeNotFound(self_idx))?; + + let sibling = self + .tree + .nodes + .get(sibling_idx) + .ok_or(RatchetTreeError::SiblingNodeNotFound(self_idx))?; + + let (left, right) = if me.value.secret_key.is_some() { + (me, sibling) + } else if sibling.value.secret_key.is_some() { + (sibling, me) + } else { + return Err(RatchetTreeError::MissingPrivateKey); + }; + + let replace_node = left.value.derive( + &right.value, + if parent_idx == 0 { + self.prev_tree_secret.as_ref() + } else { + None + }, + )?; + + let parent = self + .tree + .nodes + .get_mut(parent_idx) + .ok_or(RatchetTreeError::NodeNotFound(parent_idx))?; + + parent.value = replace_node; + self_idx = parent_idx; + } + + let new_secret = self.tree.nodes[0] + .value + .secret_key + .clone() + .map(|s| s.to_bytes()) + .expect("No secret key present"); + + if !prev_secret.eq(&new_secret) { + self.prev_tree_secret = Some(prev_secret); + + println!( + "Updated secret: {:x?}, Previous: {:x?}", + new_secret, prev_secret + ); + } + + Ok(()) + } + + // body_encryption_key = hkdf(ikm: dh(tree.priv_key, node2.pub), salt: ratchet_id, info: tree.pub_key) + fn get_body_encryption_key(&self, leaf_index: usize) -> Result<[u8; 32], RatchetTreeError> { + let child = self + .tree + .leaves() + .get(leaf_index) + .ok_or(RatchetTreeError::NodeNotFound(self.tree.index(leaf_index)))?; + let root = self + .tree + .nodes + .get(0) + .ok_or(RatchetTreeError::NodeNotFound(0))?; + + let ratchet_key = self.ratchet_key.borrow(); + let key = root + .value + .derive_key(&child.value, ratchet_key.as_ref().map(|k| &k[..]))?; + + Ok(key) + } + + // header_encryption_key = hkdf(ikm: tree.priv_key, salt: ratchet_id, info: tree.pub_key) + fn get_header_encryption_key(&self) -> Result<[u8; 32], RatchetTreeError> { + let root = self + .tree + .nodes + .get(0) + .ok_or(RatchetTreeError::NodeNotFound(0))?; + + if let Some(secret_key) = root.value.secret_key.as_ref() { + let ratchet_key: std::cell::Ref> = self.ratchet_key.borrow(); + let salt = ratchet_key.as_ref().map(|k| &k[..]); + let ikm = &secret_key.as_bytes()[..]; + + let hkdf = Hkdf::::new(salt, ikm); + let mut key = [0u8; 32]; + hkdf.expand(root.value.pub_key.as_bytes(), &mut key) + .expect("Failed to expand bytes"); + + Ok(key) + } else { + Err(RatchetTreeError::MissingPrivateKey) + } + } + + // public_encryption_key = hkdf(ikm: tree.pub_key, salt: random(), info: []) + fn get_public_encryption_key_and_salt(&self) -> Result<([u8; 32], [u8; 8]), RatchetTreeError> { + let root = self + .tree + .nodes + .get(0) + .ok_or(RatchetTreeError::NodeNotFound(0))?; + + let mut salt = [0u8; 8]; + let mut key = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut salt); + + let hkdf = Hkdf::::new(Some(&salt), &root.value.pub_key.as_bytes()[..]); + hkdf.expand(&[0u8; 0], &mut key) + .expect("Failed to expand bytes"); + + Ok((key, salt)) + } +} + +#[cfg(test)] +mod tests { + use super::RatchetKeyTree; + use crate::dh_node::DhNode; + + const PRIV_A: &str = "bde93e8e88997c7b130a827b86b5e3a33522c478a3a725b104be938d339de05d"; + const PRIV_B: &str = "80ce9bab99bd10ac0a6762e3c20a70d386eab16918741fa77d567dcfbe7ef2d5"; + const PRIV_C: &str = "ce6dbb15ddca4c5c3bf66436ecf1eef7e34716e651786adcc74ce4cf5d1f9fd2"; + const PRIV_D: &str = "e84a68f089bfb95e57c509711d05ef236c985940909e5b3940004400c928deb9"; + const PRIV_E: &str = "e24f4f02ea2270a0cc36bbb251440084b293edb4223feebcff0fab6bd58bd02a"; + + const PUB_A: &str = "1607af185dcf7c552cb8d37af5cdef29df26dde738a9da2d9f878e3c0d85e442"; + const PUB_B: &str = "636d16b0ab3f2c9d25a0aaf191415ab0817b82e51b3bcadeb325eaeac4b41802"; + const PUB_C: &str = "98588776857bb65085098455072f56d20913e56783e7a453421cf3d88d514261"; + const PUB_D: &str = "fa6e6eea0dd830a407aa184d5a56736965b19abcee7ccfec55b4099f8478860c"; + const PUB_E: &str = "1fe0fafc670b454663ebef01d2e4821960a3cf12591e726d4e44e6fbdbfc076c"; + + #[test] + fn can_create_a_tree() { + let mut rt = RatchetKeyTree::new(); + rt.add(DhNode::new(Some(123))) + .expect("First node to be able to add"); + assert_eq!(None, rt.prev_tree_secret); + assert_eq!(1, rt.tree.leaf_count()); + assert_eq!(1, rt.tree.nodes.len()); + assert_eq!(Some(123), rt.tree.nodes[0].value.user_id); + } + + #[test] + fn can_add_a_user() { + let mut rt = RatchetKeyTree::new(); + rt.add(DhNode::new(Some(123))).unwrap(); + rt.add(DhNode::new(Some(234))).expect("Failed to add user"); + assert_eq!(3, rt.tree.nodes.len()); + assert_eq!(2, rt.tree.leaf_count()); + } + + #[test] + fn can_create_a_valid_shared_secret() { + let node_a = DhNode::import(None, Some(PRIV_A), PUB_A); + let node_b = DhNode::import(None, Some(PRIV_B), PUB_B); + + let shared_secret = node_a.secret_key.unwrap().diffie_hellman(&node_b.pub_key); + assert_eq!( + "57d7c061f48818e811bd07662b263b45d081b4919e7ff54dceb05ca62f3dd47e", + hex::encode(shared_secret.as_bytes()) + ); + } +} diff --git a/src/full_binary_tree.rs b/src/full_binary_tree.rs new file mode 100644 index 0000000..9ac89a8 --- /dev/null +++ b/src/full_binary_tree.rs @@ -0,0 +1,477 @@ +use std::cmp; + +use crate::tree_node::TreeNode; + +/// Full binary tree stored in a Vec +#[derive(Clone)] +pub struct FullBinaryTree +where + T: Clone, +{ + /// Tree nodes are stored in array form, 'full' binary tree + /// + /// 0 + /// / \ + /// 1 2 == [0,1,2,3,4,5,6,7,8] + /// / \ / \ + /// 3 45 6 + /// / \ + /// 7 8 + /// So the node at [0] equals the top node and the last n nodes are the leaves. + pub nodes: Vec>, +} + +impl FullBinaryTree { + pub fn new() -> Self { + Self { nodes: Vec::new() } + } + + pub fn height(&self) -> usize { + // floor(log2(n)) + 1 + (self.nodes.len().ilog2() + 1) + .try_into() + .expect("Failed to parse u32s to usize") + } + + pub fn leaf_count(&self) -> usize { + // floor(n / 2) + n mod 2 + let node_count = self.nodes.len(); + node_count / 2 + node_count % 2 + } + + pub fn leaves(&self) -> &[TreeNode] { + &self.nodes[self.leaf_count() - 1..] + } + + pub fn leaves_mut(&mut self) -> &mut [TreeNode] { + let leaf_count = self.leaf_count(); + &mut self.nodes[leaf_count - 1..] + } + + pub fn index(&self, leaf_index: usize) -> usize { + leaf_index + (self.nodes.len() - self.leaf_count()) + } + + pub fn parent_index(&self, index: usize) -> Option { + if index == 0 { + None + } else { + // floor((index - 1) / 2) + Some((index - 1) / 2) + } + } + + pub fn sibling_index(&self, index: usize) -> Option { + if index == 0 { + return None; + } + + // (index & 1) == 1 ? index + 1 : max(0, index - 1) + let idx = if index & 1 == 1 { + index + 1 + } else { + cmp::max(0, index - 1) + }; + + if idx > self.nodes.len() - 1 { + None + } else { + Some(idx) + } + } + + pub fn left_child_index(&self, parent_index: usize) -> Option { + let left_index = 2 * parent_index + 1; + if left_index < self.nodes.len() { + Some(left_index) + } else { + None + } + } + + pub fn right_child_index(&self, parent_index: usize) -> Option { + let right_index = 2 * parent_index + 2; + if right_index < self.nodes.len() { + Some(right_index) + } else { + None + } + } + + pub fn first_expired_leaf_index(&self) -> Option { + self.leaves() + .iter() + .enumerate() + .filter(|(_, n)| n.is_expired) + .map(|(i, _)| self.index(i)) + .next() + } + + pub fn iter(&self, mode: TreeIteratorMode) -> TreeIterator { + TreeIterator { tree: &self, mode } + } + + pub fn add(&mut self, value: T) -> usize { + // If expired node exists, replace it + if self.nodes.len() > 0 { + let mut expired_nodes = self + .leaves_mut() + .iter_mut() + .enumerate() + .filter(|(_, n)| n.is_expired); + + if let Some((idx, node)) = expired_nodes.next() { + node.value = value; + node.is_expired = false; + return self.index(idx); + } + } + + let len = self.nodes.len(); + self.nodes.push(TreeNode { + value, + is_expired: false, + }); + + // Push parent node to balance tree + if len > 0 { + let idx = self.parent_index(len).unwrap(); + let parent_node = self.nodes[idx].clone(); + self.nodes.push(parent_node); + } + + len + } + + // Removes last or second last expired node from tree + // Returns index of where the other node is placed + pub fn remove_last_expired_leaf(&mut self) -> Option { + if self.nodes.len() <= 1 { + panic!("No reason to delete any node, just throw away your tree"); + } + + let index = self.index(self.leaf_count() - 1); + let sibling_index = self.sibling_index(index).expect("Expected sibling"); + let node_to_keep = match ( + self.nodes[index].is_expired, + self.nodes[sibling_index].is_expired, + ) { + (true, _) => sibling_index, + (false, true) => index, + (false, false) => return None, + }; + + let parent_index = self.parent_index(index).expect("Expected parents"); + self.nodes.swap(parent_index, node_to_keep); + self.nodes.truncate(self.nodes.len() - 2); + Some(parent_index) + } + + // Returns true if cleanup of last 2 nodes is possible + pub fn suitable_for_clean_up(&self, index: usize) -> bool { + if index >= self.nodes.len() - 2 { + let sibling_index = self.sibling_index(index).expect("Expected sibling"); + if self.nodes[sibling_index].is_expired { + return true; + } + } else if self.nodes.iter().rev().take(2).all(|n| n.is_expired) { + return true; + } + + false + } + + // Find best location to move to if we want to update a node (housekeeping) + // Returns (old_location if changed, new_location) + pub fn find_best_update_location(&self, index: usize) -> (Option, usize) { + if self.suitable_for_clean_up(index) { + ( + // If last 2 nodes, and my sibling became expired, move up. + // If last 2 nodes are both expired, move up to last 2 expired nodes + Some(index), + self.parent_index(self.nodes.len() - 1) + .expect("Expected parent"), + ) + } else if matches!(self.first_expired_leaf_index(), Some(expired_index) if expired_index < index) + { + // If other child more to the left became expired, take that place + (Some(index), self.first_expired_leaf_index().unwrap()) + } else { + // Keep in place + (None, index) + } + } +} + +pub enum TreeIteratorMode { + Parent(usize), + Finished, +} + +pub struct TreeIterator<'a, T: Clone> { + tree: &'a FullBinaryTree, + mode: TreeIteratorMode, +} + +impl<'a, T: Clone> Iterator for TreeIterator<'a, T> { + type Item = &'a TreeNode; + + fn next(&mut self) -> Option { + match self.mode { + TreeIteratorMode::Parent(index) => { + let result = self.tree.nodes.get(index); + let parent = self.tree.parent_index(index); + self.mode = match parent { + Some(index) => TreeIteratorMode::Parent(index), + None => TreeIteratorMode::Finished, + }; + result + } + TreeIteratorMode::Finished => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_tree(node_count: i32) -> FullBinaryTree { + let mut nodes = Vec::new(); + for i in 0..node_count { + nodes.push(TreeNode { + value: i, + is_expired: false, + }) + } + FullBinaryTree:: { nodes } + } + + fn actual_leaves(tree: &FullBinaryTree) -> Vec { + let actual_leaves: Vec = tree.leaves().iter().map(|l| l.value).collect(); + actual_leaves + } + + fn print_tree(tree: &FullBinaryTree) { + println!( + "Tree: {:?}", + tree.nodes.iter().map(|f| f.value).collect::>() + ); + println!("Leaves: {:?}", actual_leaves(&tree)); + } + + #[test] + fn tree_height_can_be_determined() { + let tree = test_tree(7); + assert_eq!(3, tree.height()); + + assert_eq!(vec![3, 4, 5, 6], actual_leaves(&tree)); + + assert_eq!(2, test_tree(3).height()); + assert_eq!(4, test_tree(8).height()); + assert_eq!(4, test_tree(10).height()); + assert_eq!(4, test_tree(15).height()); + assert_eq!(5, test_tree(20).height()); + assert_eq!(5, test_tree(30).height()); + } + + #[test] + fn leaf_count_can_be_determined() { + assert_eq!(2, test_tree(3).leaf_count()); + assert_eq!(3, test_tree(6).leaf_count()); + assert_eq!(4, test_tree(7).leaf_count()); + assert_eq!(8, test_tree(15).leaf_count()); + } + + #[test] + fn index_can_be_determined_from_leaf_index() { + assert_eq!(1, test_tree(3).index(0)); + assert_eq!(2, test_tree(3).index(1)); + assert_eq!(4, test_tree(7).index(1)); + assert_eq!(5, test_tree(7).index(2)); + } + + #[test] + fn parent_index_can_be_determined() { + assert_eq!(None, test_tree(7).parent_index(0)); + assert_eq!(Some(0), test_tree(7).parent_index(2)); + assert_eq!(Some(2), test_tree(7).parent_index(5)); + assert_eq!(Some(2), test_tree(7).parent_index(5)); + assert_eq!(Some(3), test_tree(9).parent_index(8)); + } + + #[test] + fn sibling_index_can_be_determined() { + assert_eq!(None, test_tree(10).sibling_index(9)); + assert_eq!(None, test_tree(10).sibling_index(0)); + assert_eq!(Some(7), test_tree(10).sibling_index(8)); + assert_eq!(Some(8), test_tree(10).sibling_index(7)); + assert_eq!(Some(6), test_tree(10).sibling_index(5)); + } + + #[test] + fn child_indexes_can_be_determined() { + assert_eq!(Some(1), test_tree(10).left_child_index(0)); + assert_eq!(Some(2), test_tree(10).right_child_index(0)); + assert_eq!(Some(7), test_tree(10).left_child_index(3)); + assert_eq!(Some(8), test_tree(10).right_child_index(3)); + assert_eq!(Some(9), test_tree(10).left_child_index(4)); + assert_eq!(None, test_tree(10).right_child_index(4)); + } + + #[test] + fn add_adds_node_and_balances_tree() { + /* + [0,1,2] + + 0 + 1 2 + + when 3 is added in a non-full binary tree that would be: [0,1,2,3] + 0 + 1 2 + 3 + + What actually happens is [0,1,2,3,1] + Results in + 0 + 1 2 + 3 1 + + Adding 4 results in + [0,1,2,3,1,4,2] + 0 + 1 2 + 3 1 4 2 + + Adding 5 results in + [0,1,2,3,1,4,2,5,3] + 0 + 1 2 + 3 1 4 2 + 5 3 + + */ + let mut tree = test_tree(3); + + let idx = tree.add(3); + assert_eq!(vec![2, 3, 1], actual_leaves(&tree)); + assert_eq!(3, tree.nodes[idx].value); + + let idx = tree.add(4); + assert_eq!(vec![3, 1, 4, 2], actual_leaves(&tree)); + assert_eq!(4, tree.nodes[idx].value); + + let idx = tree.add(5); + assert_eq!(vec![1, 4, 2, 5, 3], actual_leaves(&tree)); + assert_eq!(5, tree.nodes[idx].value); + } + + #[test] + fn is_able_to_remove_last_leaf() { + /* + Deleting nodes is complex + Given the following tree: + 0 + 1 2 + 3 4 5 6 + 7 8 + + Only last leaves can be removed + When 7 is removed, replace parent with 8 + Then remove last 2 items from array + This is ofcourse only possible if the sibling can do a key update immediately + + 0 + 1 2 + 8 4 5 6 + The next node to be able to removed is 6 or 5 + If we remove 6 + + 0 + 1 5 + 8 4 + */ + + let mut tree = test_tree(9); + assert_eq!(vec![4, 5, 6, 7, 8], actual_leaves(&tree)); + + tree.nodes[7].is_expired = true; // remove second last + assert_eq!(Some(3), tree.remove_last_expired_leaf()); + assert_eq!(tree.nodes[3].value, 8); + assert_eq!(vec![8, 4, 5, 6], actual_leaves(&tree)); + + tree.nodes[6].is_expired = true; // remove last + assert_eq!(Some(2), tree.remove_last_expired_leaf()); + assert_eq!(tree.nodes[2].value, 5); + assert_eq!(vec![5, 8, 4], actual_leaves(&tree)); + } + + #[test] + fn does_not_remove_expired_leaf_is_not_last_or_second_last() { + let mut tree = test_tree(7); + tree.nodes[3].is_expired = true; + tree.nodes[4].is_expired = true; + assert_eq!(None, tree.remove_last_expired_leaf()); + assert_eq!(vec![3, 4, 5, 6], actual_leaves(&tree)); + } + + #[test] + fn eventually_leaves_one_node_in_tree_if_nearly_all_leafs_are_expired() { + let mut tree = test_tree(7); + tree.nodes[3].is_expired = true; + tree.nodes[4].is_expired = false; + tree.nodes[5].is_expired = true; + tree.nodes[6].is_expired = true; + + assert_eq!(Some(2), tree.remove_last_expired_leaf()); + assert_eq!(vec![5, 3, 4], actual_leaves(&tree)); + + assert_eq!(Some(1), tree.remove_last_expired_leaf()); + assert_eq!(vec![4, 5], actual_leaves(&tree)); + + assert_eq!(Some(0), tree.remove_last_expired_leaf()); + assert_eq!(vec![4], actual_leaves(&tree)); + } + + #[test] + fn suitable_for_cleanup_returns_correct_value() { + let mut example = test_tree(7); + assert_eq!(false, example.suitable_for_clean_up(6)); + + example.nodes[5].is_expired = true; + assert_eq!(false, example.suitable_for_clean_up(5)); + assert_eq!(true, example.suitable_for_clean_up(6)); + + example.nodes[4].is_expired = true; + assert_eq!(false, example.suitable_for_clean_up(3)); + + example.nodes[6].is_expired = true; + assert_eq!(true, example.suitable_for_clean_up(3)); + } + + #[test] + fn is_able_to_find_best_update_location() { + /* + 0 + 1 2 + 3x 4 5x 6 + */ + + // Case 1: no node to replace + let mut example = test_tree(7); + assert_eq!((None, 6), example.find_best_update_location(6)); + + // Case 2: last or second last node, sibling is expired + example.nodes[5].is_expired = true; + assert_eq!((Some(6), 2), example.find_best_update_location(6)); + + // Case 3: node more to the left is expired + example.nodes[3].is_expired = true; + assert_eq!((Some(4), 3), example.find_best_update_location(4)); + + // Case 4: 2 last nodes are both expired (cleanup) + example.nodes[6].is_expired = true; + assert_eq!((Some(4), 2), example.find_best_update_location(4)); + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..1c705e0 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,7 @@ +mod dh_node; +mod dh_tree; +mod full_binary_tree; +mod tree_error; +mod tree_node; + +fn main() {} diff --git a/src/tree_error.rs b/src/tree_error.rs new file mode 100644 index 0000000..8e23793 --- /dev/null +++ b/src/tree_error.rs @@ -0,0 +1,9 @@ +#[derive(Debug)] +pub enum RatchetTreeError { + NodeNotFound(usize), + SiblingNodeNotFound(usize), + NotAllowedToReplaceNode, + NotAllowedToDeleteNode, + NotEnoughKeysToUpdate, + MissingPrivateKey, +} diff --git a/src/tree_node.rs b/src/tree_node.rs new file mode 100644 index 0000000..e992868 --- /dev/null +++ b/src/tree_node.rs @@ -0,0 +1,8 @@ +#[derive(Clone)] +pub struct TreeNode +where + T: Clone, +{ + pub value: T, + pub is_expired: bool, +}