diff options
-rw-r--r-- | common/common.go | 1 | ||||
-rw-r--r-- | nizk/bench_test.go | 212 | ||||
-rw-r--r-- | nizk/commit.go | 12 | ||||
-rw-r--r-- | nizk/stage1.go | 1 | ||||
-rw-r--r-- | nizk/stage2_test.go | 43 |
5 files changed, 248 insertions, 21 deletions
diff --git a/common/common.go b/common/common.go index d72daaf..4dbbd92 100644 --- a/common/common.go +++ b/common/common.go @@ -14,6 +14,7 @@ type Point = curve.Curve25519Point var Curve = curve.Curve25519 var G = Curve.Generator() var One = Curve.ScalarOne() +var Id = Curve.Identity() type Bytes interface { Bytes() []byte diff --git a/nizk/bench_test.go b/nizk/bench_test.go new file mode 100644 index 0000000..f823394 --- /dev/null +++ b/nizk/bench_test.go @@ -0,0 +1,212 @@ +package nizk + +import ( + "log" + "math/rand" + "slices" + "sync" + "testing" + + . "kesim.org/seal/common" +) + +func BenchmarkFromPaper(b *testing.B) { + bitlength := 5 + vals := []int{ + 0b01010, + 0b01001, + 0b00111, + } + ids := []Bytes{ + Curve.RandomScalar(), + Curve.RandomScalar(), + Curve.RandomScalar(), + } + + for range b.N { + var bits = [3][]*Bit{} + for i, b := range vals { + bits[i] = Int2Bits(ids[i], b, bitlength) + } + + var lost = [3]bool{} + instage1 := true + junction := -1 + result := 0 + + for idx := 0; idx < bitlength; idx++ { + var c = [3]*StageCommitment{} + var r = [3]*StageReveal{} + + for i, b := range bits { + c[i] = b[idx].StageCommit() + } + + if instage1 { + var p = [3]*Stage1Proof{} + for i := range bits { + r[i], p[i] = bits[i][idx].RevealStage1(c[0].X, c[1].X, c[2].X) + if !bits[i][idx].Commitment.VerifyStage1(c[i], r[i], p[i]) { + b.Fatalf("bits[%d][%d] commitment failed to verify in stage1", i, idx) + } + } + + Z := Curve.Product(r[0].Z, r[1].Z, r[2].Z) + if !Id.Equal(Z) { + junction = idx + instage1 = false + + for i := range bits { + if !lost[i] && !bits[i][idx].IsSet() { + lost[i] = true + } + } + result |= 1 << (bitlength - 1 - idx) + } + } else { + var bj = [3]*Bit{} + for i := range bits { + bj[i] = bits[i][junction] + } + + var p = [3]*Stage2Proof{} + for i := range bits { + r[i], p[i] = bits[i][idx].RevealStage2(lost[i], bj[i], c[0].X, c[1].X, c[2].X) + if !bits[i][idx].Commitment.VerifyStage2(bj[i].StageCommitment, c[i], bj[i].StageReveal, r[i], p[i]) { + b.Fatalf("bits[%d][%d] commitment failed to verify in stage2, result so far: %05b", i, idx, result) + } + } + + Z := Curve.Product(r[0].Z, r[1].Z, r[2].Z) + if !Id.Equal(Z) { + junction = idx + + for i := range bits { + if !lost[i] && !bits[i][idx].IsSet() { + lost[i] = true + } + } + result |= 1 << (bitlength - 1 - idx) + } + } + } + if result != vals[0] { + b.Fatalf("wrong result: %05b, expected: %05b", result, vals[0]) + } + } +} + +func runSeal(n int, bitlength int) { + var vals = make([]int, n) + var ids = make([]Bytes, n) + for i := range n { + vals[i] = rand.Intn(1<<(bitlength-1) - 1) + ids[i] = Curve.RandomScalar() + } + max := slices.Max(vals) + + var bits = make([][]*Bit, n) + for i, b := range vals { + bits[i] = Int2Bits(ids[i], b, bitlength) + } + + var c = make([]*StageCommitment, n) + var Xs = make([]*Point, n) + var r = make([]*StageReveal, n) + var p1 = make([]*Stage1Proof, n) + var Zs = make([]*Point, n) + var bj = make([]*Bit, n) + var p2 = make([]*Stage2Proof, n) + var lost = make([]bool, n) + instage1 := true + junction := -1 + result := 0 + + for idx := range bitlength { + for i := range n { + c[i] = bits[i][idx].StageCommit() + Xs[i] = c[i].X + } + if instage1 { + var wg sync.WaitGroup + wg.Add(n) + + for i := range n { + go func() { + r[i], p1[i] = bits[i][idx].RevealStage1(Xs...) + if !bits[i][idx].Commitment.VerifyStage1(c[i], r[i], p1[i]) { + log.Fatalf("bits[%d][%d] commitment failed to verify in stage1", i, idx) + } + Zs[i] = r[i].Z + wg.Done() + }() + } + wg.Wait() + + Z := Curve.Product(Zs...) + if !Id.Equal(Z) { + junction = idx + instage1 = false + for i := range bits { + if !lost[i] && !bits[i][idx].IsSet() { + lost[i] = true + } + } + result |= 1 << (bitlength - 1 - idx) + } + } else { + for i := range bits { + bj[i] = bits[i][junction] + } + + var wg sync.WaitGroup + wg.Add(n) + for i := range n { + go func() { + r[i], p2[i] = bits[i][idx].RevealStage2(lost[i], bj[i], Xs...) + if !bits[i][idx].Commitment.VerifyStage2(bj[i].StageCommitment, c[i], bj[i].StageReveal, r[i], p2[i]) { + log.Fatalf("bits[%d][%d] commitment failed to verify in stage2, result so far: %05b", i, idx, result) + } + Zs[i] = r[i].Z + wg.Done() + }() + } + wg.Wait() + + Z := Curve.Product(Zs...) + if !Id.Equal(Z) { + junction = idx + + for i := range n { + if !lost[i] && !bits[i][idx].IsSet() { + lost[i] = true + } + } + result |= 1 << (bitlength - 1 - idx) + } + } + } + if result != max { + log.Fatalf("wrong result: %0[1]*[2]b, expected: %0[1]*[3]b", bitlength, result, max) + } +} + +func benchmarkMulti(n int, bitlength int, b *testing.B) { + for range b.N { + runSeal(n, bitlength) + } +} + +func BenchmarkRun3on5bit(b *testing.B) { benchmarkMulti(3, 5, b) } +func BenchmarkRun10on8bit(b *testing.B) { benchmarkMulti(10, 8, b) } +func BenchmarkRun10on16bit(b *testing.B) { benchmarkMulti(10, 16, b) } +func BenchmarkRun10on24bit(b *testing.B) { benchmarkMulti(10, 24, b) } +func BenchmarkRun100on8bit(b *testing.B) { benchmarkMulti(100, 8, b) } +func BenchmarkRun100on16bit(b *testing.B) { benchmarkMulti(100, 16, b) } +func BenchmarkRun100on24bit(b *testing.B) { benchmarkMulti(100, 24, b) } +func BenchmarkRun500on8bit(b *testing.B) { benchmarkMulti(500, 8, b) } +func BenchmarkRun500on16bit(b *testing.B) { benchmarkMulti(500, 16, b) } +func BenchmarkRun500on24bit(b *testing.B) { benchmarkMulti(500, 24, b) } +func BenchmarkRun1000on8bit(b *testing.B) { benchmarkMulti(1000, 8, b) } +func BenchmarkRun1000on16bit(b *testing.B) { benchmarkMulti(1000, 16, b) } +func BenchmarkRun1000on24bit(b *testing.B) { benchmarkMulti(1000, 24, b) } diff --git a/nizk/commit.go b/nizk/commit.go index 93c730f..27e31f7 100644 --- a/nizk/commit.go +++ b/nizk/commit.go @@ -55,6 +55,18 @@ func NewBitFromScalars(id Bytes, set bool, α, β *Scalar) *Bit { return b } +func Int2Bits(id Bytes, val int, bitlength int) []*Bit { + if bitlength < 0 || bitlength > 24 { + return nil + } + + bits := make([]*Bit, bitlength) + for i := range bitlength { + bits[i] = NewBit(id, (val>>(bitlength-i-1))&1 != 0) + } + return bits +} + func (b *Bit) IsSet() bool { return b.set } diff --git a/nizk/stage1.go b/nizk/stage1.go index e7ed44d..f58a4ac 100644 --- a/nizk/stage1.go +++ b/nizk/stage1.go @@ -98,6 +98,7 @@ func (b *Bit) reveal(prev_true bool, Xs ...*Point) (r *StageReveal) { s.Sent = true } else { r.Z = Y.Exp(s.x) + s.Sent = false } return r diff --git a/nizk/stage2_test.go b/nizk/stage2_test.go index 4e6232e..f147bd6 100644 --- a/nizk/stage2_test.go +++ b/nizk/stage2_test.go @@ -1,6 +1,7 @@ package nizk import ( + "slices" "testing" . "kesim.org/seal/common" @@ -49,32 +50,30 @@ func TestStage2Simple1(t *testing.T) { } } -func uint2bits(bid int) []*Bit { +func int2bits(bid int, bitlength int) []*Bit { id := Curve.RandomScalar() - return []*Bit{ - NewBit(id, (bid>>4)&1 != 0), - NewBit(id, (bid>>3)&1 != 0), - NewBit(id, (bid>>2)&1 != 0), - NewBit(id, (bid>>1)&1 != 0), - NewBit(id, (bid>>0)&1 != 0), + bits := make([]*Bit, bitlength) + for i := range bitlength { + bits[i] = NewBit(id, (bid>>bitlength-i)&1 != 0) } + return bits } -var Id = Curve.Identity() - func TestStage2Complex(t *testing.T) { bid1 := 0b0101 bid2 := 0b0010 t.Logf("testing bid1: %05b vs. bid2: %05b", bid1, bid2) - bits1 := uint2bits(bid1) - bits2 := uint2bits(bid2) + bitlength := 4 + + bits1 := int2bits(bid1, bitlength) + bits2 := int2bits(bid2, bitlength) lost1 := false lost2 := false - if len(bits1) != len(bits2) || len(bits1) != 5 { + if len(bits1) != len(bits2) || len(bits1) != bitlength { t.Fatalf("oops") } @@ -117,7 +116,7 @@ func TestStage2Complex(t *testing.T) { t.Logf("setting lost2 to true") lost2 = true } - result |= 1 << (4 - c) + result |= 1 << (bitlength - 1 - c) } else { t.Logf("Z[%d] == Id, staying in stage1", c) } @@ -151,7 +150,7 @@ func TestStage2Complex(t *testing.T) { t.Logf("setting lost2 to true") lost2 = true } - result |= 1 << (4 - c) + result |= 1 << (bitlength - 1 - c) } } } @@ -161,9 +160,10 @@ func TestStage2Complex(t *testing.T) { } func TestFromPaper(t *testing.T) { + bitlength := 5 vals := []int{ - 0b01010, 0b01001, + 0b01010, 0b00111, } @@ -171,7 +171,7 @@ func TestFromPaper(t *testing.T) { var bits = [3][]*Bit{} for i, b := range vals { - bits[i] = uint2bits(b) + bits[i] = int2bits(b, bitlength) } var lost = [3]bool{} @@ -179,7 +179,7 @@ func TestFromPaper(t *testing.T) { junction := -1 result := 0 - for idx := 0; idx < 5; idx++ { + for idx := 0; idx < bitlength; idx++ { var c = [3]*StageCommitment{} var r = [3]*StageReveal{} @@ -212,7 +212,7 @@ func TestFromPaper(t *testing.T) { t.Logf("bit %d, set lost[%d] to true, so far: %v", idx, i, lost) } } - result |= 1 << (4 - idx) + result |= 1 << (bitlength - 1 - idx) } else { t.Logf("Z[%d] == Id, staying in stage1", idx) } @@ -249,12 +249,13 @@ func TestFromPaper(t *testing.T) { t.Logf("bits[%d][%d], set lost[%d] to true, so far: %v", i, idx, i, lost) } } - result |= 1 << (4 - idx) + result |= 1 << (bitlength - 1 - idx) } } } - if result != vals[0] { - t.Fatalf("wrong result: %05b", result) + max := slices.Max(vals) + if result != max { + t.Fatalf("wrong result: %05b, expected: %05b", result, max) } } |