refactor: Change initialisations of field elements to be cleaner

This commit is contained in:
Alivecow 2024-11-29 19:31:10 +01:00
parent bf4c3ee4ca
commit 2623bd9a8d

View file

@ -32,15 +32,13 @@ impl Polynomial {
} }
pub fn one() -> Self { pub fn one() -> Self {
Polynomial::new(vec![FieldElement::new( Polynomial::new(vec![FieldElement::one()])
polynomial_2_block(vec![0], "gcm").unwrap(),
)])
} }
pub fn x() -> Self { pub fn x() -> Self {
Polynomial::new(vec![ Polynomial::new(vec![
FieldElement::new(vec![0; 16]), FieldElement::new(vec![0; 16]),
FieldElement::new(polynomial_2_block(vec![0], "gcm").unwrap()), FieldElement::new(polynomial_2_block(vec![0], "xex").unwrap()),
]) ])
} }
@ -54,7 +52,7 @@ impl Polynomial {
} }
pub fn zero() -> Self { pub fn zero() -> Self {
Polynomial::new(vec![FieldElement::new(vec![0; 16])]) Polynomial::new(vec![FieldElement::zero()])
} }
pub fn from_c_array(array: &Value) -> Self { pub fn from_c_array(array: &Value) -> Self {
@ -83,7 +81,7 @@ impl Polynomial {
pub fn to_c_array(self) -> Vec<String> { pub fn to_c_array(self) -> Vec<String> {
let mut output: Vec<String> = vec![]; let mut output: Vec<String> = vec![];
for coeff in self.polynomial { for coeff in self.polynomial {
output.push(BASE64_STANDARD.encode(coeff)); output.push(coeff.to_b64());
} }
output output
@ -209,9 +207,7 @@ impl Polynomial {
} }
if exponent == 0 { if exponent == 0 {
let result = Polynomial::new(vec![FieldElement::new( let result = Polynomial::new(vec![FieldElement::one()]);
polynomial_2_block(vec![0], "gcm").unwrap(),
)]);
return result; return result;
} }
@ -258,10 +254,7 @@ impl Polynomial {
//eprintln!("{:?}, {:?}", self.polynomial.len(), rhs.polynomial.len()); //eprintln!("{:?}, {:?}", self.polynomial.len(), rhs.polynomial.len());
if self.polynomial.len() < rhs.polynomial.len() { if self.polynomial.len() < rhs.polynomial.len() {
return ( return (Polynomial::new(vec![FieldElement::zero()]), self.clone());
Polynomial::new(vec![FieldElement::new(vec![0; 16])]),
self.clone(),
);
} }
let mut remainder = self.clone(); let mut remainder = self.clone();
@ -270,16 +263,10 @@ impl Polynomial {
let divisor_deg = divisor.polynomial.len() - 1; let divisor_deg = divisor.polynomial.len() - 1;
if dividend_deg < divisor_deg { if dividend_deg < divisor_deg {
return ( return (Polynomial::new(vec![FieldElement::zero()]), remainder);
Polynomial::new(vec![FieldElement::new(
polynomial_2_block(vec![0; 16], "gcm").unwrap(),
)]),
remainder,
);
} }
let mut quotient_coeffs = let mut quotient_coeffs = vec![FieldElement::zero(); dividend_deg - divisor_deg + 1];
vec![FieldElement::new(vec![0; 16]); dividend_deg - divisor_deg + 1];
while remainder.polynomial.len() >= divisor.polynomial.len() { while remainder.polynomial.len() >= divisor.polynomial.len() {
let deg_diff = remainder.polynomial.len() - divisor.polynomial.len(); let deg_diff = remainder.polynomial.len() - divisor.polynomial.len();
@ -290,7 +277,7 @@ impl Polynomial {
quotient_coeffs[deg_diff] = quot_coeff.clone(); quotient_coeffs[deg_diff] = quot_coeff.clone();
let mut subtrahend = vec![FieldElement::new(vec![0; 16]); deg_diff]; let mut subtrahend = vec![FieldElement::zero(); deg_diff];
subtrahend.extend( subtrahend.extend(
divisor divisor
.polynomial .polynomial
@ -315,7 +302,7 @@ impl Polynomial {
} }
if remainder.is_empty() { if remainder.is_empty() {
remainder = Polynomial::new(vec![FieldElement::new(vec![0; 16])]); remainder = Polynomial::new(vec![FieldElement::zero()]);
} }
(Polynomial::new(quotient_coeffs), remainder) (Polynomial::new(quotient_coeffs), remainder)
} }
@ -416,10 +403,10 @@ impl Mul for Polynomial {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self::Output { fn mul(self, rhs: Self) -> Self::Output {
if self.is_zero() || rhs.is_zero() { if self.is_zero() || rhs.is_zero() {
return Polynomial::new(vec![FieldElement::new(vec![0; 16])]); return Polynomial::zero();
} }
let mut polynomial: Vec<FieldElement> = let mut polynomial: Vec<FieldElement> =
vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1]; vec![FieldElement::zero(); self.polynomial.len() + rhs.polynomial.len() - 1];
for i in 0..self.polynomial.len() { for i in 0..self.polynomial.len() {
for j in 0..rhs.polynomial.len() { for j in 0..rhs.polynomial.len() {
polynomial[i + j] = &polynomial[i + j] polynomial[i + j] = &polynomial[i + j]
@ -434,10 +421,10 @@ impl Mul for &Polynomial {
type Output = Polynomial; type Output = Polynomial;
fn mul(self, rhs: Self) -> Self::Output { fn mul(self, rhs: Self) -> Self::Output {
if self.is_zero() || rhs.is_zero() { if self.is_zero() || rhs.is_zero() {
return Polynomial::new(vec![FieldElement::new(vec![0])]); return Polynomial::zero();
} }
let mut polynomial: Vec<FieldElement> = let mut polynomial: Vec<FieldElement> =
vec![FieldElement::new(vec![0; 16]); self.polynomial.len() + rhs.polynomial.len() - 1]; vec![FieldElement::zero(); self.polynomial.len() + rhs.polynomial.len() - 1];
for i in 0..self.polynomial.len() { for i in 0..self.polynomial.len() {
for j in 0..rhs.polynomial.len() { for j in 0..rhs.polynomial.len() {
polynomial[i + j] = &polynomial[i + j] polynomial[i + j] = &polynomial[i + j]
@ -471,7 +458,7 @@ impl Add for Polynomial {
} }
if polynomial.is_empty() { if polynomial.is_empty() {
return Polynomial::new(vec![FieldElement::new(vec![0; 16])]); return Polynomial::new(vec![FieldElement::zero()]);
} }
Polynomial::new(polynomial) Polynomial::new(polynomial)
@ -514,8 +501,8 @@ impl PartialOrd for Polynomial {
self.as_ref().iter().rev().zip(other.as_ref().iter().rev()) self.as_ref().iter().rev().zip(other.as_ref().iter().rev())
{ {
match field_a match field_a
.reverse_bits() //.reverse_bits()
.partial_cmp(&field_b.reverse_bits()) .partial_cmp(&field_b)
.unwrap() .unwrap()
{ {
Ordering::Equal => continue, Ordering::Equal => continue,
@ -538,7 +525,10 @@ impl Ord for Polynomial {
for (field_a, field_b) in for (field_a, field_b) in
self.as_ref().iter().rev().zip(other.as_ref().iter().rev()) self.as_ref().iter().rev().zip(other.as_ref().iter().rev())
{ {
match field_a.reverse_bits().cmp(&field_b.reverse_bits()) { match field_a
//.reverse_bits()
.cmp(&field_b)
{
Ordering::Equal => continue, Ordering::Equal => continue,
other => return other, other => return other,
} }
@ -1115,19 +1105,6 @@ mod tests {
//assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA=="); //assert_eq!(BASE64_STANDARD.encode(product), "MoAAAAAAAAAAAAAAAAAAAA==");
} }
#[test]
fn test_poly_div_01() {
let element1 =
FieldElement::new(BASE64_STANDARD.decode("JAAAAAAAAAAAAAAAAAAAAA==").unwrap());
let element2 =
FieldElement::new(BASE64_STANDARD.decode("wAAAAAAAAAAAAAAAAAAAAA==").unwrap());
let result = element1 / element2;
assert_eq!(BASE64_STANDARD.encode(result), "OAAAAAAAAAAAAAAAAAAAAA==");
}
#[test] #[test]
fn test_field_poly_div_01() { fn test_field_poly_div_01() {
let json1 = json!([ let json1 = json!([