Initial commit

This commit is contained in:
Maurice 2024-08-15 22:09:53 +02:00
commit 4ed19f9198
9 changed files with 1348 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

380
Cargo.lock generated Normal file
View File

@ -0,0 +1,380 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cpufeatures"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad"
dependencies = [
"libc",
]
[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be"
dependencies = [
"cfg-if",
"cpufeatures",
"curve25519-dalek-derive",
"fiat-crypto",
"rustc_version",
"subtle",
"zeroize",
]
[[package]]
name = "curve25519-dalek-derive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "dh-tree-example"
version = "0.1.0"
dependencies = [
"hex",
"hkdf",
"num-bigint",
"rand",
"sha2",
"x25519-dalek",
]
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
name = "fiat-crypto"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "getrandom"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hkdf"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7"
dependencies = [
"hmac",
]
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "libc"
version = "0.2.155"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "ppv-lite86"
version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]]
name = "semver"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]]
name = "serde"
version = "1.0.207"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5665e14a49a4ea1b91029ba7d3bca9f299e1f7cfa194388ccc20f14743e784f2"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.207"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6aea2634c86b0e8ef2cfdc0c340baede54ec27b1e46febd7f80dffb2aa44a00e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "subtle"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "2.0.74"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fceb41e3d546d0bd83421d3409b1460cc7444cd389341a4c880fe7a042cb3d7"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "version_check"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "x25519-dalek"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277"
dependencies = [
"curve25519-dalek",
"rand_core",
"serde",
"zeroize",
]
[[package]]
name = "zerocopy"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [
"byteorder",
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
dependencies = [
"zeroize_derive",
]
[[package]]
name = "zeroize_derive"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

12
Cargo.toml Normal file
View File

@ -0,0 +1,12 @@
[package]
name = "dh-tree-example"
version = "0.1.0"
edition = "2021"
[dependencies]
hex = "0.4.3"
hkdf = "0.12.4"
num-bigint = "0.4.6"
rand = "0.8.5"
sha2 = "0.10.8"
x25519-dalek = { version = "2", features = ["static_secrets"]}

128
src/dh_node.rs Normal file
View File

@ -0,0 +1,128 @@
use std::borrow::Borrow;
use hkdf::Hkdf;
use num_bigint::BigInt;
use sha2::Sha256;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::tree_error::RatchetTreeError;
#[derive(Clone)]
pub struct DhNode {
pub pub_key: PublicKey,
pub secret_key: Option<StaticSecret>,
pub user_id: Option<i32>,
}
impl DhNode {
pub fn new(user_id: Option<i32>) -> Self {
let secret_key = StaticSecret::random_from_rng(&mut rand::thread_rng());
let pub_key = PublicKey::from(&secret_key);
Self {
pub_key,
secret_key: Some(secret_key),
user_id,
}
}
pub fn import(user_id: Option<i32>, secret_key: Option<&str>, public_key: &str) -> Self {
let secret_key = secret_key
.map(hex::decode)
.map(Result::unwrap)
.map(TryInto::try_into)
.map(Result::unwrap)
.map(|b: [u8; 32]| StaticSecret::from(b));
let pub_key = TryInto::<[u8; 32]>::try_into(hex::decode(public_key).unwrap())
.map(|b| PublicKey::from(b))
.unwrap();
Self {
pub_key,
secret_key,
user_id,
}
}
}
impl DhNode {
/// Derive parent node from this node and its sibling
pub fn derive(
&self,
sibling: &DhNode,
chain_key: Option<&[u8; 32]>,
) -> Result<Self, RatchetTreeError> {
let key = self.derive_key(sibling, chain_key.map(|k| &k[..]))?;
let secret_key = StaticSecret::from(key);
let pub_key = PublicKey::from(&secret_key);
Ok(Self {
pub_key,
secret_key: Some(secret_key),
user_id: None,
})
}
/// secret_key = hkdf(salt: chain_key, info: sum_big_endian(self.pub_key, other.pub_key), ikm: shared_secret(self.priv_key, other.pub_key))
pub fn derive_key(
&self,
other: &DhNode,
salt: Option<&[u8]>,
) -> Result<[u8; 32], RatchetTreeError> {
// Determine which party has a private key
let (secret_key, public_key) = if let Some(secret_key) = self.secret_key.borrow() {
(secret_key, &other.pub_key)
} else if let Some(secret_key) = other.secret_key.borrow() {
(secret_key, &self.pub_key)
} else {
return Err(RatchetTreeError::MissingPrivateKey);
};
let shared_secret = secret_key.diffie_hellman(public_key);
let hkdf = Hkdf::<Sha256>::new(salt, shared_secret.as_bytes());
let pub_key_sum = BigInt::from_signed_bytes_be(self.pub_key.as_bytes())
+ BigInt::from_signed_bytes_be(other.pub_key.as_bytes());
let mut result = [0u8; 32];
hkdf.expand(&pub_key_sum.to_signed_bytes_be(), &mut result)
.expect("Failed to expand bytes");
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
pub fn can_derive_same_key() {
let pair1 = DhNode::new(None);
let pair2 = DhNode::new(None);
let pair12 = pair1.derive(&pair2, None).unwrap();
let pair21 = pair2.derive(&pair1, None).unwrap();
assert_eq!(
pair12.secret_key.unwrap().to_bytes(),
pair21.secret_key.unwrap().to_bytes()
);
}
#[test]
pub fn derivation_with_chain_key_yields_other_result() {
let pair1 = DhNode::new(None);
let pair2 = DhNode::new(None);
let chain_key = &[1u8; 32];
let pair12 = pair1.derive(&pair2, None).unwrap();
let pair12b = pair2.derive(&pair1, Some(chain_key)).unwrap();
assert_ne!(
pair12.secret_key.unwrap().to_bytes(),
pair12b.secret_key.unwrap().to_bytes()
);
}
}

326
src/dh_tree.rs Normal file
View File

@ -0,0 +1,326 @@
use std::{cell::RefCell, collections::VecDeque};
use hkdf::Hkdf;
use rand::RngCore;
use sha2::Sha256;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::{
dh_node::DhNode,
full_binary_tree::{FullBinaryTree, TreeIteratorMode},
tree_error::RatchetTreeError,
};
#[derive(Clone)]
pub struct RatchetKeyTree {
pub tree: FullBinaryTree<DhNode>,
pub prev_tree_secret: Option<[u8; 32]>,
pub ratchet_key: RefCell<Option<[u8; 32]>>,
}
pub struct RatchetTreeKeyUpdate {
pub user_id: Option<i32>,
pub index: usize,
pub prev_index: Option<usize>,
pub new_keys: VecDeque<[u8; 32]>,
}
// TODO: scan for nodes to remove
// TODO: key update replace other node, set other to remove
impl RatchetKeyTree {
pub fn new() -> Self {
Self {
tree: FullBinaryTree::new(),
prev_tree_secret: None,
ratchet_key: RefCell::new(None),
}
}
pub fn create_key_update(&self, index: usize) -> (RatchetTreeKeyUpdate, StaticSecret) {
let mut simulation = self.clone();
let (old_location, new_location) = simulation.tree.find_best_update_location(index);
let user_id = simulation.tree.nodes[index].value.user_id;
let new_node: &mut crate::tree_node::TreeNode<DhNode> =
&mut simulation.tree.nodes[new_location];
let new_data = DhNode::new(user_id);
let new_key = new_data.secret_key.clone().unwrap();
new_node.value = new_data;
simulation
.update(new_location)
.expect("Failed to create update on simulation");
let new_keys: VecDeque<[u8; 32]> = simulation
.tree
.iter(TreeIteratorMode::Parent(new_location))
.map(|n| n.value.pub_key.as_bytes().clone())
.collect();
(
RatchetTreeKeyUpdate {
new_keys,
user_id,
index: new_location,
prev_index: old_location,
},
new_key,
)
}
pub fn apply_key_update(
&mut self,
update: RatchetTreeKeyUpdate,
) -> Result<(), RatchetTreeError> {
let current_item = self
.tree
.nodes
.get(update.index)
.ok_or(RatchetTreeError::NodeNotFound(update.index))?;
if !current_item.is_expired && current_item.value.user_id != update.user_id {
return Err(RatchetTreeError::NotAllowedToReplaceNode);
}
if let Some(prev_index) = update.prev_index {
// Try to cleanup if cleanup is possible (you are neighbour of expired node and last or second last,
// or last two nodes are both expired)
if self.tree.suitable_for_clean_up(prev_index) {
if let Some(index) = self.tree.remove_last_expired_leaf() {
if !index == update.index {
return Err(RatchetTreeError::NotAllowedToDeleteNode);
}
};
}
let prev_node = &mut self.tree.nodes[prev_index];
prev_node.is_expired = true;
prev_node.value.user_id = None;
prev_node.value.secret_key = None;
}
let mut next_index = Some(update.index);
let mut keys = update.new_keys;
while let Some(index) = next_index {
let node = self
.tree
.nodes
.get_mut(index)
.ok_or(RatchetTreeError::NodeNotFound(index))?;
let new_key = keys
.pop_front()
.ok_or(RatchetTreeError::NotEnoughKeysToUpdate)?;
node.value.pub_key = PublicKey::from(new_key);
node.value.secret_key = None;
node.is_expired = false;
node.value.user_id = if index == update.index {
update.user_id
} else {
None
};
next_index = self.tree.parent_index(index);
}
Ok(())
}
pub fn add(&mut self, new_node: DhNode) -> Result<(), RatchetTreeError> {
let index = self.tree.add(new_node);
self.update(index)?;
Ok(())
}
fn update(&mut self, index: usize) -> Result<(), RatchetTreeError> {
let prev_secret = self.tree.nodes[0]
.value
.secret_key
.clone()
.map(|s| s.to_bytes())
.expect("No secret key present");
let mut self_idx = index;
while let Some(parent_idx) = self.tree.parent_index(self_idx) {
let me = self
.tree
.nodes
.get(self_idx)
.ok_or(RatchetTreeError::NodeNotFound(self_idx))?;
let sibling_idx = self
.tree
.sibling_index(self_idx)
.ok_or(RatchetTreeError::SiblingNodeNotFound(self_idx))?;
let sibling = self
.tree
.nodes
.get(sibling_idx)
.ok_or(RatchetTreeError::SiblingNodeNotFound(self_idx))?;
let (left, right) = if me.value.secret_key.is_some() {
(me, sibling)
} else if sibling.value.secret_key.is_some() {
(sibling, me)
} else {
return Err(RatchetTreeError::MissingPrivateKey);
};
let replace_node = left.value.derive(
&right.value,
if parent_idx == 0 {
self.prev_tree_secret.as_ref()
} else {
None
},
)?;
let parent = self
.tree
.nodes
.get_mut(parent_idx)
.ok_or(RatchetTreeError::NodeNotFound(parent_idx))?;
parent.value = replace_node;
self_idx = parent_idx;
}
let new_secret = self.tree.nodes[0]
.value
.secret_key
.clone()
.map(|s| s.to_bytes())
.expect("No secret key present");
if !prev_secret.eq(&new_secret) {
self.prev_tree_secret = Some(prev_secret);
println!(
"Updated secret: {:x?}, Previous: {:x?}",
new_secret, prev_secret
);
}
Ok(())
}
// body_encryption_key = hkdf(ikm: dh(tree.priv_key, node2.pub), salt: ratchet_id, info: tree.pub_key)
fn get_body_encryption_key(&self, leaf_index: usize) -> Result<[u8; 32], RatchetTreeError> {
let child = self
.tree
.leaves()
.get(leaf_index)
.ok_or(RatchetTreeError::NodeNotFound(self.tree.index(leaf_index)))?;
let root = self
.tree
.nodes
.get(0)
.ok_or(RatchetTreeError::NodeNotFound(0))?;
let ratchet_key = self.ratchet_key.borrow();
let key = root
.value
.derive_key(&child.value, ratchet_key.as_ref().map(|k| &k[..]))?;
Ok(key)
}
// header_encryption_key = hkdf(ikm: tree.priv_key, salt: ratchet_id, info: tree.pub_key)
fn get_header_encryption_key(&self) -> Result<[u8; 32], RatchetTreeError> {
let root = self
.tree
.nodes
.get(0)
.ok_or(RatchetTreeError::NodeNotFound(0))?;
if let Some(secret_key) = root.value.secret_key.as_ref() {
let ratchet_key: std::cell::Ref<Option<[u8; 32]>> = self.ratchet_key.borrow();
let salt = ratchet_key.as_ref().map(|k| &k[..]);
let ikm = &secret_key.as_bytes()[..];
let hkdf = Hkdf::<Sha256>::new(salt, ikm);
let mut key = [0u8; 32];
hkdf.expand(root.value.pub_key.as_bytes(), &mut key)
.expect("Failed to expand bytes");
Ok(key)
} else {
Err(RatchetTreeError::MissingPrivateKey)
}
}
// public_encryption_key = hkdf(ikm: tree.pub_key, salt: random(), info: [])
fn get_public_encryption_key_and_salt(&self) -> Result<([u8; 32], [u8; 8]), RatchetTreeError> {
let root = self
.tree
.nodes
.get(0)
.ok_or(RatchetTreeError::NodeNotFound(0))?;
let mut salt = [0u8; 8];
let mut key = [0u8; 32];
rand::thread_rng().fill_bytes(&mut salt);
let hkdf = Hkdf::<Sha256>::new(Some(&salt), &root.value.pub_key.as_bytes()[..]);
hkdf.expand(&[0u8; 0], &mut key)
.expect("Failed to expand bytes");
Ok((key, salt))
}
}
#[cfg(test)]
mod tests {
use super::RatchetKeyTree;
use crate::dh_node::DhNode;
const PRIV_A: &str = "bde93e8e88997c7b130a827b86b5e3a33522c478a3a725b104be938d339de05d";
const PRIV_B: &str = "80ce9bab99bd10ac0a6762e3c20a70d386eab16918741fa77d567dcfbe7ef2d5";
const PRIV_C: &str = "ce6dbb15ddca4c5c3bf66436ecf1eef7e34716e651786adcc74ce4cf5d1f9fd2";
const PRIV_D: &str = "e84a68f089bfb95e57c509711d05ef236c985940909e5b3940004400c928deb9";
const PRIV_E: &str = "e24f4f02ea2270a0cc36bbb251440084b293edb4223feebcff0fab6bd58bd02a";
const PUB_A: &str = "1607af185dcf7c552cb8d37af5cdef29df26dde738a9da2d9f878e3c0d85e442";
const PUB_B: &str = "636d16b0ab3f2c9d25a0aaf191415ab0817b82e51b3bcadeb325eaeac4b41802";
const PUB_C: &str = "98588776857bb65085098455072f56d20913e56783e7a453421cf3d88d514261";
const PUB_D: &str = "fa6e6eea0dd830a407aa184d5a56736965b19abcee7ccfec55b4099f8478860c";
const PUB_E: &str = "1fe0fafc670b454663ebef01d2e4821960a3cf12591e726d4e44e6fbdbfc076c";
#[test]
fn can_create_a_tree() {
let mut rt = RatchetKeyTree::new();
rt.add(DhNode::new(Some(123)))
.expect("First node to be able to add");
assert_eq!(None, rt.prev_tree_secret);
assert_eq!(1, rt.tree.leaf_count());
assert_eq!(1, rt.tree.nodes.len());
assert_eq!(Some(123), rt.tree.nodes[0].value.user_id);
}
#[test]
fn can_add_a_user() {
let mut rt = RatchetKeyTree::new();
rt.add(DhNode::new(Some(123))).unwrap();
rt.add(DhNode::new(Some(234))).expect("Failed to add user");
assert_eq!(3, rt.tree.nodes.len());
assert_eq!(2, rt.tree.leaf_count());
}
#[test]
fn can_create_a_valid_shared_secret() {
let node_a = DhNode::import(None, Some(PRIV_A), PUB_A);
let node_b = DhNode::import(None, Some(PRIV_B), PUB_B);
let shared_secret = node_a.secret_key.unwrap().diffie_hellman(&node_b.pub_key);
assert_eq!(
"57d7c061f48818e811bd07662b263b45d081b4919e7ff54dceb05ca62f3dd47e",
hex::encode(shared_secret.as_bytes())
);
}
}

477
src/full_binary_tree.rs Normal file
View File

@ -0,0 +1,477 @@
use std::cmp;
use crate::tree_node::TreeNode;
/// Full binary tree stored in a Vec
#[derive(Clone)]
pub struct FullBinaryTree<T>
where
T: Clone,
{
/// Tree nodes are stored in array form, 'full' binary tree
///
/// 0
/// / \
/// 1 2 == [0,1,2,3,4,5,6,7,8]
/// / \ / \
/// 3 45 6
/// / \
/// 7 8
/// So the node at [0] equals the top node and the last n nodes are the leaves.
pub nodes: Vec<TreeNode<T>>,
}
impl<T: Clone> FullBinaryTree<T> {
pub fn new() -> Self {
Self { nodes: Vec::new() }
}
pub fn height(&self) -> usize {
// floor(log2(n)) + 1
(self.nodes.len().ilog2() + 1)
.try_into()
.expect("Failed to parse u32s to usize")
}
pub fn leaf_count(&self) -> usize {
// floor(n / 2) + n mod 2
let node_count = self.nodes.len();
node_count / 2 + node_count % 2
}
pub fn leaves(&self) -> &[TreeNode<T>] {
&self.nodes[self.leaf_count() - 1..]
}
pub fn leaves_mut(&mut self) -> &mut [TreeNode<T>] {
let leaf_count = self.leaf_count();
&mut self.nodes[leaf_count - 1..]
}
pub fn index(&self, leaf_index: usize) -> usize {
leaf_index + (self.nodes.len() - self.leaf_count())
}
pub fn parent_index(&self, index: usize) -> Option<usize> {
if index == 0 {
None
} else {
// floor((index - 1) / 2)
Some((index - 1) / 2)
}
}
pub fn sibling_index(&self, index: usize) -> Option<usize> {
if index == 0 {
return None;
}
// (index & 1) == 1 ? index + 1 : max(0, index - 1)
let idx = if index & 1 == 1 {
index + 1
} else {
cmp::max(0, index - 1)
};
if idx > self.nodes.len() - 1 {
None
} else {
Some(idx)
}
}
pub fn left_child_index(&self, parent_index: usize) -> Option<usize> {
let left_index = 2 * parent_index + 1;
if left_index < self.nodes.len() {
Some(left_index)
} else {
None
}
}
pub fn right_child_index(&self, parent_index: usize) -> Option<usize> {
let right_index = 2 * parent_index + 2;
if right_index < self.nodes.len() {
Some(right_index)
} else {
None
}
}
pub fn first_expired_leaf_index(&self) -> Option<usize> {
self.leaves()
.iter()
.enumerate()
.filter(|(_, n)| n.is_expired)
.map(|(i, _)| self.index(i))
.next()
}
pub fn iter(&self, mode: TreeIteratorMode) -> TreeIterator<T> {
TreeIterator { tree: &self, mode }
}
pub fn add(&mut self, value: T) -> usize {
// If expired node exists, replace it
if self.nodes.len() > 0 {
let mut expired_nodes = self
.leaves_mut()
.iter_mut()
.enumerate()
.filter(|(_, n)| n.is_expired);
if let Some((idx, node)) = expired_nodes.next() {
node.value = value;
node.is_expired = false;
return self.index(idx);
}
}
let len = self.nodes.len();
self.nodes.push(TreeNode {
value,
is_expired: false,
});
// Push parent node to balance tree
if len > 0 {
let idx = self.parent_index(len).unwrap();
let parent_node = self.nodes[idx].clone();
self.nodes.push(parent_node);
}
len
}
// Removes last or second last expired node from tree
// Returns index of where the other node is placed
pub fn remove_last_expired_leaf(&mut self) -> Option<usize> {
if self.nodes.len() <= 1 {
panic!("No reason to delete any node, just throw away your tree");
}
let index = self.index(self.leaf_count() - 1);
let sibling_index = self.sibling_index(index).expect("Expected sibling");
let node_to_keep = match (
self.nodes[index].is_expired,
self.nodes[sibling_index].is_expired,
) {
(true, _) => sibling_index,
(false, true) => index,
(false, false) => return None,
};
let parent_index = self.parent_index(index).expect("Expected parents");
self.nodes.swap(parent_index, node_to_keep);
self.nodes.truncate(self.nodes.len() - 2);
Some(parent_index)
}
// Returns true if cleanup of last 2 nodes is possible
pub fn suitable_for_clean_up(&self, index: usize) -> bool {
if index >= self.nodes.len() - 2 {
let sibling_index = self.sibling_index(index).expect("Expected sibling");
if self.nodes[sibling_index].is_expired {
return true;
}
} else if self.nodes.iter().rev().take(2).all(|n| n.is_expired) {
return true;
}
false
}
// Find best location to move to if we want to update a node (housekeeping)
// Returns (old_location if changed, new_location)
pub fn find_best_update_location(&self, index: usize) -> (Option<usize>, usize) {
if self.suitable_for_clean_up(index) {
(
// If last 2 nodes, and my sibling became expired, move up.
// If last 2 nodes are both expired, move up to last 2 expired nodes
Some(index),
self.parent_index(self.nodes.len() - 1)
.expect("Expected parent"),
)
} else if matches!(self.first_expired_leaf_index(), Some(expired_index) if expired_index < index)
{
// If other child more to the left became expired, take that place
(Some(index), self.first_expired_leaf_index().unwrap())
} else {
// Keep in place
(None, index)
}
}
}
pub enum TreeIteratorMode {
Parent(usize),
Finished,
}
pub struct TreeIterator<'a, T: Clone> {
tree: &'a FullBinaryTree<T>,
mode: TreeIteratorMode,
}
impl<'a, T: Clone> Iterator for TreeIterator<'a, T> {
type Item = &'a TreeNode<T>;
fn next(&mut self) -> Option<Self::Item> {
match self.mode {
TreeIteratorMode::Parent(index) => {
let result = self.tree.nodes.get(index);
let parent = self.tree.parent_index(index);
self.mode = match parent {
Some(index) => TreeIteratorMode::Parent(index),
None => TreeIteratorMode::Finished,
};
result
}
TreeIteratorMode::Finished => None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_tree(node_count: i32) -> FullBinaryTree<i32> {
let mut nodes = Vec::new();
for i in 0..node_count {
nodes.push(TreeNode {
value: i,
is_expired: false,
})
}
FullBinaryTree::<i32> { nodes }
}
fn actual_leaves(tree: &FullBinaryTree<i32>) -> Vec<i32> {
let actual_leaves: Vec<i32> = tree.leaves().iter().map(|l| l.value).collect();
actual_leaves
}
fn print_tree(tree: &FullBinaryTree<i32>) {
println!(
"Tree: {:?}",
tree.nodes.iter().map(|f| f.value).collect::<Vec<i32>>()
);
println!("Leaves: {:?}", actual_leaves(&tree));
}
#[test]
fn tree_height_can_be_determined() {
let tree = test_tree(7);
assert_eq!(3, tree.height());
assert_eq!(vec![3, 4, 5, 6], actual_leaves(&tree));
assert_eq!(2, test_tree(3).height());
assert_eq!(4, test_tree(8).height());
assert_eq!(4, test_tree(10).height());
assert_eq!(4, test_tree(15).height());
assert_eq!(5, test_tree(20).height());
assert_eq!(5, test_tree(30).height());
}
#[test]
fn leaf_count_can_be_determined() {
assert_eq!(2, test_tree(3).leaf_count());
assert_eq!(3, test_tree(6).leaf_count());
assert_eq!(4, test_tree(7).leaf_count());
assert_eq!(8, test_tree(15).leaf_count());
}
#[test]
fn index_can_be_determined_from_leaf_index() {
assert_eq!(1, test_tree(3).index(0));
assert_eq!(2, test_tree(3).index(1));
assert_eq!(4, test_tree(7).index(1));
assert_eq!(5, test_tree(7).index(2));
}
#[test]
fn parent_index_can_be_determined() {
assert_eq!(None, test_tree(7).parent_index(0));
assert_eq!(Some(0), test_tree(7).parent_index(2));
assert_eq!(Some(2), test_tree(7).parent_index(5));
assert_eq!(Some(2), test_tree(7).parent_index(5));
assert_eq!(Some(3), test_tree(9).parent_index(8));
}
#[test]
fn sibling_index_can_be_determined() {
assert_eq!(None, test_tree(10).sibling_index(9));
assert_eq!(None, test_tree(10).sibling_index(0));
assert_eq!(Some(7), test_tree(10).sibling_index(8));
assert_eq!(Some(8), test_tree(10).sibling_index(7));
assert_eq!(Some(6), test_tree(10).sibling_index(5));
}
#[test]
fn child_indexes_can_be_determined() {
assert_eq!(Some(1), test_tree(10).left_child_index(0));
assert_eq!(Some(2), test_tree(10).right_child_index(0));
assert_eq!(Some(7), test_tree(10).left_child_index(3));
assert_eq!(Some(8), test_tree(10).right_child_index(3));
assert_eq!(Some(9), test_tree(10).left_child_index(4));
assert_eq!(None, test_tree(10).right_child_index(4));
}
#[test]
fn add_adds_node_and_balances_tree() {
/*
[0,1,2]
0
1 2
when 3 is added in a non-full binary tree that would be: [0,1,2,3]
0
1 2
3
What actually happens is [0,1,2,3,1]
Results in
0
1 2
3 1
Adding 4 results in
[0,1,2,3,1,4,2]
0
1 2
3 1 4 2
Adding 5 results in
[0,1,2,3,1,4,2,5,3]
0
1 2
3 1 4 2
5 3
*/
let mut tree = test_tree(3);
let idx = tree.add(3);
assert_eq!(vec![2, 3, 1], actual_leaves(&tree));
assert_eq!(3, tree.nodes[idx].value);
let idx = tree.add(4);
assert_eq!(vec![3, 1, 4, 2], actual_leaves(&tree));
assert_eq!(4, tree.nodes[idx].value);
let idx = tree.add(5);
assert_eq!(vec![1, 4, 2, 5, 3], actual_leaves(&tree));
assert_eq!(5, tree.nodes[idx].value);
}
#[test]
fn is_able_to_remove_last_leaf() {
/*
Deleting nodes is complex
Given the following tree:
0
1 2
3 4 5 6
7 8
Only last leaves can be removed
When 7 is removed, replace parent with 8
Then remove last 2 items from array
This is ofcourse only possible if the sibling can do a key update immediately
0
1 2
8 4 5 6
The next node to be able to removed is 6 or 5
If we remove 6
0
1 5
8 4
*/
let mut tree = test_tree(9);
assert_eq!(vec![4, 5, 6, 7, 8], actual_leaves(&tree));
tree.nodes[7].is_expired = true; // remove second last
assert_eq!(Some(3), tree.remove_last_expired_leaf());
assert_eq!(tree.nodes[3].value, 8);
assert_eq!(vec![8, 4, 5, 6], actual_leaves(&tree));
tree.nodes[6].is_expired = true; // remove last
assert_eq!(Some(2), tree.remove_last_expired_leaf());
assert_eq!(tree.nodes[2].value, 5);
assert_eq!(vec![5, 8, 4], actual_leaves(&tree));
}
#[test]
fn does_not_remove_expired_leaf_is_not_last_or_second_last() {
let mut tree = test_tree(7);
tree.nodes[3].is_expired = true;
tree.nodes[4].is_expired = true;
assert_eq!(None, tree.remove_last_expired_leaf());
assert_eq!(vec![3, 4, 5, 6], actual_leaves(&tree));
}
#[test]
fn eventually_leaves_one_node_in_tree_if_nearly_all_leafs_are_expired() {
let mut tree = test_tree(7);
tree.nodes[3].is_expired = true;
tree.nodes[4].is_expired = false;
tree.nodes[5].is_expired = true;
tree.nodes[6].is_expired = true;
assert_eq!(Some(2), tree.remove_last_expired_leaf());
assert_eq!(vec![5, 3, 4], actual_leaves(&tree));
assert_eq!(Some(1), tree.remove_last_expired_leaf());
assert_eq!(vec![4, 5], actual_leaves(&tree));
assert_eq!(Some(0), tree.remove_last_expired_leaf());
assert_eq!(vec![4], actual_leaves(&tree));
}
#[test]
fn suitable_for_cleanup_returns_correct_value() {
let mut example = test_tree(7);
assert_eq!(false, example.suitable_for_clean_up(6));
example.nodes[5].is_expired = true;
assert_eq!(false, example.suitable_for_clean_up(5));
assert_eq!(true, example.suitable_for_clean_up(6));
example.nodes[4].is_expired = true;
assert_eq!(false, example.suitable_for_clean_up(3));
example.nodes[6].is_expired = true;
assert_eq!(true, example.suitable_for_clean_up(3));
}
#[test]
fn is_able_to_find_best_update_location() {
/*
0
1 2
3x 4 5x 6
*/
// Case 1: no node to replace
let mut example = test_tree(7);
assert_eq!((None, 6), example.find_best_update_location(6));
// Case 2: last or second last node, sibling is expired
example.nodes[5].is_expired = true;
assert_eq!((Some(6), 2), example.find_best_update_location(6));
// Case 3: node more to the left is expired
example.nodes[3].is_expired = true;
assert_eq!((Some(4), 3), example.find_best_update_location(4));
// Case 4: 2 last nodes are both expired (cleanup)
example.nodes[6].is_expired = true;
assert_eq!((Some(4), 2), example.find_best_update_location(4));
}
}

7
src/main.rs Normal file
View File

@ -0,0 +1,7 @@
mod dh_node;
mod dh_tree;
mod full_binary_tree;
mod tree_error;
mod tree_node;
fn main() {}

9
src/tree_error.rs Normal file
View File

@ -0,0 +1,9 @@
#[derive(Debug)]
pub enum RatchetTreeError {
NodeNotFound(usize),
SiblingNodeNotFound(usize),
NotAllowedToReplaceNode,
NotAllowedToDeleteNode,
NotEnoughKeysToUpdate,
MissingPrivateKey,
}

8
src/tree_node.rs Normal file
View File

@ -0,0 +1,8 @@
#[derive(Clone)]
pub struct TreeNode<T>
where
T: Clone,
{
pub value: T,
pub is_expired: bool,
}