aboutsummaryrefslogtreecommitdiff
path: root/nizk/stage2_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'nizk/stage2_test.go')
-rw-r--r--nizk/stage2_test.go43
1 files changed, 22 insertions, 21 deletions
diff --git a/nizk/stage2_test.go b/nizk/stage2_test.go
index 4e6232e..f147bd6 100644
--- a/nizk/stage2_test.go
+++ b/nizk/stage2_test.go
@@ -1,6 +1,7 @@
package nizk
import (
+ "slices"
"testing"
. "kesim.org/seal/common"
@@ -49,32 +50,30 @@ func TestStage2Simple1(t *testing.T) {
}
}
-func uint2bits(bid int) []*Bit {
+func int2bits(bid int, bitlength int) []*Bit {
id := Curve.RandomScalar()
- return []*Bit{
- NewBit(id, (bid>>4)&1 != 0),
- NewBit(id, (bid>>3)&1 != 0),
- NewBit(id, (bid>>2)&1 != 0),
- NewBit(id, (bid>>1)&1 != 0),
- NewBit(id, (bid>>0)&1 != 0),
+ bits := make([]*Bit, bitlength)
+ for i := range bitlength {
+ bits[i] = NewBit(id, (bid>>bitlength-i)&1 != 0)
}
+ return bits
}
-var Id = Curve.Identity()
-
func TestStage2Complex(t *testing.T) {
bid1 := 0b0101
bid2 := 0b0010
t.Logf("testing bid1: %05b vs. bid2: %05b", bid1, bid2)
- bits1 := uint2bits(bid1)
- bits2 := uint2bits(bid2)
+ bitlength := 4
+
+ bits1 := int2bits(bid1, bitlength)
+ bits2 := int2bits(bid2, bitlength)
lost1 := false
lost2 := false
- if len(bits1) != len(bits2) || len(bits1) != 5 {
+ if len(bits1) != len(bits2) || len(bits1) != bitlength {
t.Fatalf("oops")
}
@@ -117,7 +116,7 @@ func TestStage2Complex(t *testing.T) {
t.Logf("setting lost2 to true")
lost2 = true
}
- result |= 1 << (4 - c)
+ result |= 1 << (bitlength - 1 - c)
} else {
t.Logf("Z[%d] == Id, staying in stage1", c)
}
@@ -151,7 +150,7 @@ func TestStage2Complex(t *testing.T) {
t.Logf("setting lost2 to true")
lost2 = true
}
- result |= 1 << (4 - c)
+ result |= 1 << (bitlength - 1 - c)
}
}
}
@@ -161,9 +160,10 @@ func TestStage2Complex(t *testing.T) {
}
func TestFromPaper(t *testing.T) {
+ bitlength := 5
vals := []int{
- 0b01010,
0b01001,
+ 0b01010,
0b00111,
}
@@ -171,7 +171,7 @@ func TestFromPaper(t *testing.T) {
var bits = [3][]*Bit{}
for i, b := range vals {
- bits[i] = uint2bits(b)
+ bits[i] = int2bits(b, bitlength)
}
var lost = [3]bool{}
@@ -179,7 +179,7 @@ func TestFromPaper(t *testing.T) {
junction := -1
result := 0
- for idx := 0; idx < 5; idx++ {
+ for idx := 0; idx < bitlength; idx++ {
var c = [3]*StageCommitment{}
var r = [3]*StageReveal{}
@@ -212,7 +212,7 @@ func TestFromPaper(t *testing.T) {
t.Logf("bit %d, set lost[%d] to true, so far: %v", idx, i, lost)
}
}
- result |= 1 << (4 - idx)
+ result |= 1 << (bitlength - 1 - idx)
} else {
t.Logf("Z[%d] == Id, staying in stage1", idx)
}
@@ -249,12 +249,13 @@ func TestFromPaper(t *testing.T) {
t.Logf("bits[%d][%d], set lost[%d] to true, so far: %v", i, idx, i, lost)
}
}
- result |= 1 << (4 - idx)
+ result |= 1 << (bitlength - 1 - idx)
}
}
}
- if result != vals[0] {
- t.Fatalf("wrong result: %05b", result)
+ max := slices.Max(vals)
+ if result != max {
+ t.Fatalf("wrong result: %05b, expected: %05b", result, max)
}
}