Add basic pfmath functionality #13

Merged
0xalivecow merged 7 commits from dev into main 2024-11-14 22:12:02 +00:00
Showing only changes of commit 6e33e2e44c - Show all commits

View file

@ -1,13 +1,104 @@
use std::ops::{Add, Mul}; use std::{
env::args,
ops::{Add, Mul},
};
use anyhow::{anyhow, Ok, Result}; use anyhow::{anyhow, Ok, Result};
use base64::prelude::*; use base64::prelude::*;
use serde_json::Value;
use super::{math::xor_bytes, poly::gfmul}; use super::{math::xor_bytes, poly::gfmul};
#[derive(Debug)]
pub struct Polynomial {
polynomial: Vec<FieldElement>,
}
impl Polynomial {
pub const fn new(polynomial: Vec<FieldElement>) -> Self {
Self { polynomial }
}
pub fn from_c_array(array: &Value) -> Self {
let mut polynomial: Vec<FieldElement> = vec![];
let c_array: Vec<String> = array
.as_array()
.expect("Input is not an array")
.iter()
.map(|x| {
x.as_str()
.expect("Array element is not a string")
.to_string()
})
.collect();
eprintln!("{:?}", c_array);
for coefficient in c_array {
polynomial.push(FieldElement::new(
BASE64_STANDARD
.decode(coefficient)
.expect("Error on poly decode:"),
));
}
Self { polynomial }
}
pub fn to_c_array(self) -> Vec<String> {
let mut output: Vec<String> = vec![];
for coeff in self.polynomial {
output.push(BASE64_STANDARD.encode(coeff));
}
output
}
}
impl Mul for Polynomial {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
let mut polynomial: Vec<FieldElement> =
vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1];
for i in 0..self.polynomial.len() {
for j in 0..rhs.polynomial.len() {
polynomial[i + j] = &polynomial[i + j]
+ &(self.polynomial.get(i).unwrap() * rhs.polynomial.get(j).unwrap());
}
}
Polynomial::new(polynomial)
}
}
impl Add for Polynomial {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
FieldElement::new(
xor_bytes(&self.field_element, rhs.field_element).expect("Error in poly add"),
)
}
}
/*
impl Add for Polynomial {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
FieldElement::new(
xor_bytes(&self.field_element, rhs.field_element).expect("Error in poly add"),
)
}
}
impl AsRef<[u8]> for Polynomial {
fn as_ref(&self) -> &[u8] {
&self.field_element.as_ref()
}
}
*/
#[derive(Debug)] #[derive(Debug)]
pub struct FieldElement { pub struct FieldElement {
polynomial: Vec<u8>, field_element: Vec<u8>,
} }
impl FieldElement { impl FieldElement {
@ -15,8 +106,8 @@ impl FieldElement {
87, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 01, 87, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 00, 01,
]; ];
pub const fn new(polynomial: Vec<u8>) -> Self { pub const fn new(field_element: Vec<u8>) -> Self {
Self { polynomial } Self { field_element }
} }
pub fn mul(&self, poly_a: Vec<u8>, poly_b: Vec<u8>) -> Result<Vec<u8>> { pub fn mul(&self, poly_a: Vec<u8>, poly_b: Vec<u8>) -> Result<Vec<u8>> {
@ -29,7 +120,19 @@ impl Mul for FieldElement {
fn mul(self, rhs: Self) -> Self::Output { fn mul(self, rhs: Self) -> Self::Output {
FieldElement::new( FieldElement::new(
gfmul(self.polynomial, rhs.polynomial, "gcm").expect("Error during multiplication"), gfmul(self.field_element, rhs.field_element, "gcm")
.expect("Error during multiplication"),
)
}
}
impl Mul for &FieldElement {
type Output = FieldElement;
fn mul(self, rhs: &FieldElement) -> FieldElement {
FieldElement::new(
gfmul(self.field_element.clone(), rhs.field_element.clone(), "gcm")
.expect("Error during multiplication"),
) )
} }
} }
@ -37,13 +140,32 @@ impl Mul for FieldElement {
impl Add for FieldElement { impl Add for FieldElement {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self::Output { fn add(self, rhs: Self) -> Self::Output {
FieldElement::new(xor_bytes(&self.polynomial, rhs.polynomial).expect("Error in poly add")) FieldElement::new(
xor_bytes(&self.field_element, rhs.field_element).expect("Error in poly add"),
)
}
}
impl Add for &FieldElement {
type Output = FieldElement;
fn add(self, rhs: Self) -> Self::Output {
FieldElement::new(
xor_bytes(&self.field_element, rhs.field_element.clone()).expect("Error in poly add"),
)
} }
} }
impl AsRef<[u8]> for FieldElement { impl AsRef<[u8]> for FieldElement {
fn as_ref(&self) -> &[u8] { fn as_ref(&self) -> &[u8] {
&self.polynomial.as_ref() &self.field_element.as_ref()
}
}
impl Clone for FieldElement {
fn clone(&self) -> Self {
FieldElement {
field_element: self.field_element.clone(),
}
} }
} }
@ -158,6 +280,7 @@ impl ByteArray {
mod tests { mod tests {
use super::*; use super::*;
use base64::prelude::*; use base64::prelude::*;
use serde_json::json;
use std::fs; use std::fs;
#[test] #[test]
@ -263,4 +386,53 @@ mod tests {
assert_eq!(BASE64_STANDARD.encode(sum), "H1d3GuyA9/0OxeYouUpAAA=="); assert_eq!(BASE64_STANDARD.encode(sum), "H1d3GuyA9/0OxeYouUpAAA==");
} }
#[test]
fn test_field_add_02() {
let element1: FieldElement =
FieldElement::new(BASE64_STANDARD.decode("NeverGonnaLetYouDownAA==").unwrap());
let element2: FieldElement =
FieldElement::new(BASE64_STANDARD.decode("DHBWMannheimAAAAAAAAAA==").unwrap());
let sum = element2 + element1;
assert_eq!(BASE64_STANDARD.encode(sum), "OZuIncPAGEp4tYouDownAA==");
}
#[test]
fn test_field_add_03() {
let json1 = json!([
"NeverGonnaGiveYouUpAAA==",
"NeverGonnaLetYouDownAA==",
"NeverGonnaRunAroundAAA==",
"AndDesertYouAAAAAAAAAA=="
]);
let json2 = json!(["KryptoanalyseAAAAAAAAA==", "DHBWMannheimAAAAAAAAAA=="]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let element2: Polynomial = Polynomial::from_c_array(&json2);
let sum = element2 + element1;
assert_eq!(BASE64_STANDARD.encode(sum), "OZuIncPAGEp4tYouDownAA==");
}
#[test]
fn test_field_mul_01() {
let json1 = json!([
"JAAAAAAAAAAAAAAAAAAAAA==",
"wAAAAAAAAAAAAAAAAAAAAA==",
"ACAAAAAAAAAAAAAAAAAAAA=="
]);
let json2 = json!(["0AAAAAAAAAAAAAAAAAAAAA==", "IQAAAAAAAAAAAAAAAAAAAA=="]);
let element1: Polynomial = Polynomial::from_c_array(&json1);
let element2: Polynomial = Polynomial::from_c_array(&json2);
//eprintln!("{:?}", element1);
let result = element1 * element2;
eprintln!("Result = {:?}", result.to_c_array());
assert!(false);
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
}
} }