diff --git a/src/utils/field.rs b/src/utils/field.rs index 37ceb74..f0da8f0 100644 --- a/src/utils/field.rs +++ b/src/utils/field.rs @@ -1,13 +1,104 @@ -use std::ops::{Add, Mul}; +use std::{ + env::args, + ops::{Add, Mul}, +}; use anyhow::{anyhow, Ok, Result}; use base64::prelude::*; +use serde_json::Value; use super::{math::xor_bytes, poly::gfmul}; +#[derive(Debug)] +pub struct Polynomial { + polynomial: Vec, +} + +impl Polynomial { + pub const fn new(polynomial: Vec) -> Self { + Self { polynomial } + } + + pub fn from_c_array(array: &Value) -> Self { + let mut polynomial: Vec = vec![]; + let c_array: Vec = array + .as_array() + .expect("Input is not an array") + .iter() + .map(|x| { + x.as_str() + .expect("Array element is not a string") + .to_string() + }) + .collect(); + + eprintln!("{:?}", c_array); + + for coefficient in c_array { + polynomial.push(FieldElement::new( + BASE64_STANDARD + .decode(coefficient) + .expect("Error on poly decode:"), + )); + } + Self { polynomial } + } + + pub fn to_c_array(self) -> Vec { + let mut output: Vec = vec![]; + for coeff in self.polynomial { + output.push(BASE64_STANDARD.encode(coeff)); + } + + output + } +} + +impl Mul for Polynomial { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + let mut polynomial: Vec = + vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1]; + for i in 0..self.polynomial.len() { + for j in 0..rhs.polynomial.len() { + polynomial[i + j] = &polynomial[i + j] + + &(self.polynomial.get(i).unwrap() * rhs.polynomial.get(j).unwrap()); + } + } + Polynomial::new(polynomial) + } +} + +impl Add for Polynomial { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + FieldElement::new( + xor_bytes(&self.field_element, rhs.field_element).expect("Error in poly add"), + ) + } +} + +/* +impl Add for Polynomial { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { + FieldElement::new( + xor_bytes(&self.field_element, rhs.field_element).expect("Error in poly add"), + ) + } +} + +impl AsRef<[u8]> for Polynomial { + fn as_ref(&self) -> &[u8] { + &self.field_element.as_ref() + } +} +*/ + #[derive(Debug)] pub struct FieldElement { - polynomial: Vec, + field_element: Vec, } impl FieldElement { @@ -15,8 +106,8 @@ impl FieldElement { 87, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 01, ]; - pub const fn new(polynomial: Vec) -> Self { - Self { polynomial } + pub const fn new(field_element: Vec) -> Self { + Self { field_element } } pub fn mul(&self, poly_a: Vec, poly_b: Vec) -> Result> { @@ -29,7 +120,19 @@ impl Mul for FieldElement { fn mul(self, rhs: Self) -> Self::Output { FieldElement::new( - gfmul(self.polynomial, rhs.polynomial, "gcm").expect("Error during multiplication"), + gfmul(self.field_element, rhs.field_element, "gcm") + .expect("Error during multiplication"), + ) + } +} + +impl Mul for &FieldElement { + type Output = FieldElement; + + fn mul(self, rhs: &FieldElement) -> FieldElement { + FieldElement::new( + gfmul(self.field_element.clone(), rhs.field_element.clone(), "gcm") + .expect("Error during multiplication"), ) } } @@ -37,13 +140,32 @@ impl Mul for FieldElement { impl Add for FieldElement { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - FieldElement::new(xor_bytes(&self.polynomial, rhs.polynomial).expect("Error in poly add")) + FieldElement::new( + xor_bytes(&self.field_element, rhs.field_element).expect("Error in poly add"), + ) + } +} + +impl Add for &FieldElement { + type Output = FieldElement; + fn add(self, rhs: Self) -> Self::Output { + FieldElement::new( + xor_bytes(&self.field_element, rhs.field_element.clone()).expect("Error in poly add"), + ) } } impl AsRef<[u8]> for FieldElement { fn as_ref(&self) -> &[u8] { - &self.polynomial.as_ref() + &self.field_element.as_ref() + } +} + +impl Clone for FieldElement { + fn clone(&self) -> Self { + FieldElement { + field_element: self.field_element.clone(), + } } } @@ -158,6 +280,7 @@ impl ByteArray { mod tests { use super::*; use base64::prelude::*; + use serde_json::json; use std::fs; #[test] @@ -263,4 +386,53 @@ mod tests { assert_eq!(BASE64_STANDARD.encode(sum), "H1d3GuyA9/0OxeYouUpAAA=="); } + + #[test] + fn test_field_add_02() { + let element1: FieldElement = + FieldElement::new(BASE64_STANDARD.decode("NeverGonnaLetYouDownAA==").unwrap()); + let element2: FieldElement = + FieldElement::new(BASE64_STANDARD.decode("DHBWMannheimAAAAAAAAAA==").unwrap()); + let sum = element2 + element1; + + assert_eq!(BASE64_STANDARD.encode(sum), "OZuIncPAGEp4tYouDownAA=="); + } + + #[test] + fn test_field_add_03() { + let json1 = json!([ + "NeverGonnaGiveYouUpAAA==", + "NeverGonnaLetYouDownAA==", + "NeverGonnaRunAroundAAA==", + "AndDesertYouAAAAAAAAAA==" + ]); + let json2 = json!(["KryptoanalyseAAAAAAAAA==", "DHBWMannheimAAAAAAAAAA=="]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + let element2: Polynomial = Polynomial::from_c_array(&json2); + + let sum = element2 + element1; + + assert_eq!(BASE64_STANDARD.encode(sum), "OZuIncPAGEp4tYouDownAA=="); + } + + #[test] + fn test_field_mul_01() { + let json1 = json!([ + "JAAAAAAAAAAAAAAAAAAAAA==", + "wAAAAAAAAAAAAAAAAAAAAA==", + "ACAAAAAAAAAAAAAAAAAAAA==" + ]); + let json2 = json!(["0AAAAAAAAAAAAAAAAAAAAA==", "IQAAAAAAAAAAAAAAAAAAAA=="]); + let element1: Polynomial = Polynomial::from_c_array(&json1); + let element2: Polynomial = Polynomial::from_c_array(&json2); + + //eprintln!("{:?}", element1); + + let result = element1 * element2; + + eprintln!("Result = {:?}", result.to_c_array()); + + assert!(false); + //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); + } }