diff options
Diffstat (limited to 'vickrey_test.go')
-rw-r--r-- | vickrey_test.go | 209 |
1 files changed, 209 insertions, 0 deletions
diff --git a/vickrey_test.go b/vickrey_test.go new file mode 100644 index 0000000..e36a579 --- /dev/null +++ b/vickrey_test.go @@ -0,0 +1,209 @@ +package seal + +import ( + "math/rand" + "slices" + "strings" + "sync" + "sync/atomic" + "testing" + + . "kesim.org/seal/common" +) + +type stage int + +const ( + stage1 stage = iota + stage2 +) + +func (s stage) String() string { + if s == stage1 { + return "stage1" + } + return "stage2" +} + +func runVickrey(n int, bitlength int, tb testing.TB) { + var vals = make([]int, n) + var ids = make([]Bytes, n) + for i := range n { + vals[i] = rand.Intn(1<<(bitlength-1) - 1) + ids[i] = Curve.RandomScalar() + } + max := slices.Max(vals) + max_idx := slices.Index(vals, max) + max2 := slices.Max(slices.Delete(slices.Clone(vals), max_idx, max_idx+1)) + if max == max2 { + max_idx = -1 + } + + tb.Logf("running vickrey for vals:\n%0[1]*[2]b\nmax: %0[1]*[3]b, max2: %0[1]*[4]b, winner: %d\n", bitlength, vals, max, max2, max_idx) + + var bits = make([][]*Bit, n) + for i, b := range vals { + bits[i] = Int2Bits(ids[i], b, bitlength) + } + + var c = make([]*StageCommitment, n) + var Xs = make([]*Point, n) + var r = make([]*StageReveal, n) + var p1 = make([]*Stage1Proof, n) + var Zs = make([]*Point, n) + var bj = make([]*Bit, n) + var p2 = make([]*Stage2Proof, n) + var lost = make([]bool, n) + stage := stage1 + junction := -1 + winner := -1 + result := 0 + + isWinner := func(Z *Point, i int, idx int) bool { + z := Z.Div(Zs[i]) + xu := Curve.Identity() + xl := Curve.Identity() + found := false + + for k := range n { + if k == winner { + continue + } + if k < i { + xu = xu.Mul(Xs[k]) + } else if k > i { + xl = xl.Mul(Xs[k]) + } + } + xu = xu.Exp(bits[i][idx].Stage.x) + xl = xl.Exp(bits[i][idx].Stage.x) + x := xu.Div(xl) + + if x.Equal(z) { + tb.Logf("equal by value") + found = true + } + + if winner < 0 { + // BUG! TODO! + s1 := x.String() + s2 := z.String() + if strings.HasPrefix(s1, s2[:len(s2)-2]) { + tb.Logf("BUG! TODO! equal only by string") + found = true + } + } + // tb.Logf("testing max_idx %d, i %d, bit %d:\n%v vs %v", max_idx, i, idx, x, z) + return found + } + + for idx := range bitlength { + for i := range n { + if i == winner { + Xs[i] = Id + Zs[i] = Id + continue + } + c[i] = bits[i][idx].StageCommit() + Xs[i] = c[i].X + if stage == stage2 { + bj[i] = bits[i][junction] + } + } + + var wg sync.WaitGroup + wg.Add(n) + fail := &atomic.Bool{} + for i := range n { + if i == winner { + tb.Logf("%s, idx: %d, skipping winner %d", stage, idx, winner) + wg.Done() + continue + } + go func(i int) { + defer wg.Done() + if stage == stage1 { + r[i], p1[i] = bits[i][idx].RevealStage1(Xs...) + if !bits[i][idx].Commitment.VerifyStage1(c[i], r[i], p1[i]) { + fail.Store(true) + tb.Fatalf("bits[i: %d][idx: %d] commitment failed to verify in stage1", i, idx) + } + } else { + r[i], p2[i] = bits[i][idx].RevealStage2(lost[i], bj[i], Xs...) + if !bits[i][idx].Commitment.VerifyStage2(bj[i].StageCommitment, c[i], bj[i].StageReveal, r[i], p2[i]) { + fail.Store(true) + tb.Fatalf("bits[i: %d][idx: %d] (junction: %d) verify failed in stage2, lost: %t, result so far: %05b\nXs: %v", i, idx, junction, lost[i], result, Xs) + } + } + Zs[i] = r[i].Z + }(i) + } + wg.Wait() + if fail.Load() { + tb.Fail() + return + } + + Z := Curve.Product(Zs...) + tb.Logf("Z[idx: %d]: %v", idx, Z) + reset := false + if !Id.Equal(Z) { + var lost_round = make([]bool, n) + for i := range n { + if i == winner { + continue + } + if !lost[i] && !bits[i][idx].IsSet() { + lost_round[i] = true + } + + // Winner test + if winner < 0 && !lost_round[i] && !lost[i] { + reset = isWinner(Z, i, idx) + if reset { + tb.Logf("found winner %d %s, idx %d", i, stage, idx) + winner = i + // stage = stage1 + break + } + } + } + if !reset { + result |= 1 << (bitlength - 1 - idx) + junction = idx + stage = stage2 + for i := range n { + lost[i] = lost[i] || lost_round[i] + } + } + } + tb.Logf("lost: %t, result: %08b", lost, result) + } + if result != max2 { + tb.Fatalf("wrong result: %0[1]*[2]b, exp. max2: %0[1]*[3]b, max: %0[1]*[4]b\nvals: %0[1]*[5]b", bitlength, result, max2, max, vals) + } + if max_idx != winner { + tb.Fatalf("wrong winner, max_idx: %d vs winner: %d val %08b\nvals: %08b", max_idx, winner, max2, vals) + } +} + +func TestSeal4on6bit(t *testing.T) { runSeal(4, 6, t) } +func TestSeal100on24bit(t *testing.T) { + if testing.Short() { + t.Skip("skipping vickrey 100, 16") + } + runSeal(100, 24, t) +} + +func TestVickrey4on6bit(t *testing.T) { runVickrey(4, 6, t) } +func TestVickrey100on16bit(t *testing.T) { + if testing.Short() { + t.Skip("skipping vickrey 100, 16") + } + runVickrey(100, 16, t) +} +func BenchmarkVickrey100on24bit(b *testing.B) { + for range b.N { + runVickrey(100, 24, b) + } +} |