aboutsummaryrefslogtreecommitdiff
path: root/nizk
diff options
context:
space:
mode:
Diffstat (limited to 'nizk')
-rw-r--r--nizk/bench_test.go212
-rw-r--r--nizk/commit.go12
-rw-r--r--nizk/stage1.go1
-rw-r--r--nizk/stage2_test.go43
4 files changed, 247 insertions, 21 deletions
diff --git a/nizk/bench_test.go b/nizk/bench_test.go
new file mode 100644
index 0000000..f823394
--- /dev/null
+++ b/nizk/bench_test.go
@@ -0,0 +1,212 @@
+package nizk
+
+import (
+ "log"
+ "math/rand"
+ "slices"
+ "sync"
+ "testing"
+
+ . "kesim.org/seal/common"
+)
+
+func BenchmarkFromPaper(b *testing.B) {
+ bitlength := 5
+ vals := []int{
+ 0b01010,
+ 0b01001,
+ 0b00111,
+ }
+ ids := []Bytes{
+ Curve.RandomScalar(),
+ Curve.RandomScalar(),
+ Curve.RandomScalar(),
+ }
+
+ for range b.N {
+ var bits = [3][]*Bit{}
+ for i, b := range vals {
+ bits[i] = Int2Bits(ids[i], b, bitlength)
+ }
+
+ var lost = [3]bool{}
+ instage1 := true
+ junction := -1
+ result := 0
+
+ for idx := 0; idx < bitlength; idx++ {
+ var c = [3]*StageCommitment{}
+ var r = [3]*StageReveal{}
+
+ for i, b := range bits {
+ c[i] = b[idx].StageCommit()
+ }
+
+ if instage1 {
+ var p = [3]*Stage1Proof{}
+ for i := range bits {
+ r[i], p[i] = bits[i][idx].RevealStage1(c[0].X, c[1].X, c[2].X)
+ if !bits[i][idx].Commitment.VerifyStage1(c[i], r[i], p[i]) {
+ b.Fatalf("bits[%d][%d] commitment failed to verify in stage1", i, idx)
+ }
+ }
+
+ Z := Curve.Product(r[0].Z, r[1].Z, r[2].Z)
+ if !Id.Equal(Z) {
+ junction = idx
+ instage1 = false
+
+ for i := range bits {
+ if !lost[i] && !bits[i][idx].IsSet() {
+ lost[i] = true
+ }
+ }
+ result |= 1 << (bitlength - 1 - idx)
+ }
+ } else {
+ var bj = [3]*Bit{}
+ for i := range bits {
+ bj[i] = bits[i][junction]
+ }
+
+ var p = [3]*Stage2Proof{}
+ for i := range bits {
+ r[i], p[i] = bits[i][idx].RevealStage2(lost[i], bj[i], c[0].X, c[1].X, c[2].X)
+ if !bits[i][idx].Commitment.VerifyStage2(bj[i].StageCommitment, c[i], bj[i].StageReveal, r[i], p[i]) {
+ b.Fatalf("bits[%d][%d] commitment failed to verify in stage2, result so far: %05b", i, idx, result)
+ }
+ }
+
+ Z := Curve.Product(r[0].Z, r[1].Z, r[2].Z)
+ if !Id.Equal(Z) {
+ junction = idx
+
+ for i := range bits {
+ if !lost[i] && !bits[i][idx].IsSet() {
+ lost[i] = true
+ }
+ }
+ result |= 1 << (bitlength - 1 - idx)
+ }
+ }
+ }
+ if result != vals[0] {
+ b.Fatalf("wrong result: %05b, expected: %05b", result, vals[0])
+ }
+ }
+}
+
+func runSeal(n int, bitlength int) {
+ var vals = make([]int, n)
+ var ids = make([]Bytes, n)
+ for i := range n {
+ vals[i] = rand.Intn(1<<(bitlength-1) - 1)
+ ids[i] = Curve.RandomScalar()
+ }
+ max := slices.Max(vals)
+
+ var bits = make([][]*Bit, n)
+ for i, b := range vals {
+ bits[i] = Int2Bits(ids[i], b, bitlength)
+ }
+
+ var c = make([]*StageCommitment, n)
+ var Xs = make([]*Point, n)
+ var r = make([]*StageReveal, n)
+ var p1 = make([]*Stage1Proof, n)
+ var Zs = make([]*Point, n)
+ var bj = make([]*Bit, n)
+ var p2 = make([]*Stage2Proof, n)
+ var lost = make([]bool, n)
+ instage1 := true
+ junction := -1
+ result := 0
+
+ for idx := range bitlength {
+ for i := range n {
+ c[i] = bits[i][idx].StageCommit()
+ Xs[i] = c[i].X
+ }
+ if instage1 {
+ var wg sync.WaitGroup
+ wg.Add(n)
+
+ for i := range n {
+ go func() {
+ r[i], p1[i] = bits[i][idx].RevealStage1(Xs...)
+ if !bits[i][idx].Commitment.VerifyStage1(c[i], r[i], p1[i]) {
+ log.Fatalf("bits[%d][%d] commitment failed to verify in stage1", i, idx)
+ }
+ Zs[i] = r[i].Z
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ Z := Curve.Product(Zs...)
+ if !Id.Equal(Z) {
+ junction = idx
+ instage1 = false
+ for i := range bits {
+ if !lost[i] && !bits[i][idx].IsSet() {
+ lost[i] = true
+ }
+ }
+ result |= 1 << (bitlength - 1 - idx)
+ }
+ } else {
+ for i := range bits {
+ bj[i] = bits[i][junction]
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(n)
+ for i := range n {
+ go func() {
+ r[i], p2[i] = bits[i][idx].RevealStage2(lost[i], bj[i], Xs...)
+ if !bits[i][idx].Commitment.VerifyStage2(bj[i].StageCommitment, c[i], bj[i].StageReveal, r[i], p2[i]) {
+ log.Fatalf("bits[%d][%d] commitment failed to verify in stage2, result so far: %05b", i, idx, result)
+ }
+ Zs[i] = r[i].Z
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ Z := Curve.Product(Zs...)
+ if !Id.Equal(Z) {
+ junction = idx
+
+ for i := range n {
+ if !lost[i] && !bits[i][idx].IsSet() {
+ lost[i] = true
+ }
+ }
+ result |= 1 << (bitlength - 1 - idx)
+ }
+ }
+ }
+ if result != max {
+ log.Fatalf("wrong result: %0[1]*[2]b, expected: %0[1]*[3]b", bitlength, result, max)
+ }
+}
+
+func benchmarkMulti(n int, bitlength int, b *testing.B) {
+ for range b.N {
+ runSeal(n, bitlength)
+ }
+}
+
+func BenchmarkRun3on5bit(b *testing.B) { benchmarkMulti(3, 5, b) }
+func BenchmarkRun10on8bit(b *testing.B) { benchmarkMulti(10, 8, b) }
+func BenchmarkRun10on16bit(b *testing.B) { benchmarkMulti(10, 16, b) }
+func BenchmarkRun10on24bit(b *testing.B) { benchmarkMulti(10, 24, b) }
+func BenchmarkRun100on8bit(b *testing.B) { benchmarkMulti(100, 8, b) }
+func BenchmarkRun100on16bit(b *testing.B) { benchmarkMulti(100, 16, b) }
+func BenchmarkRun100on24bit(b *testing.B) { benchmarkMulti(100, 24, b) }
+func BenchmarkRun500on8bit(b *testing.B) { benchmarkMulti(500, 8, b) }
+func BenchmarkRun500on16bit(b *testing.B) { benchmarkMulti(500, 16, b) }
+func BenchmarkRun500on24bit(b *testing.B) { benchmarkMulti(500, 24, b) }
+func BenchmarkRun1000on8bit(b *testing.B) { benchmarkMulti(1000, 8, b) }
+func BenchmarkRun1000on16bit(b *testing.B) { benchmarkMulti(1000, 16, b) }
+func BenchmarkRun1000on24bit(b *testing.B) { benchmarkMulti(1000, 24, b) }
diff --git a/nizk/commit.go b/nizk/commit.go
index 93c730f..27e31f7 100644
--- a/nizk/commit.go
+++ b/nizk/commit.go
@@ -55,6 +55,18 @@ func NewBitFromScalars(id Bytes, set bool, α, β *Scalar) *Bit {
return b
}
+func Int2Bits(id Bytes, val int, bitlength int) []*Bit {
+ if bitlength < 0 || bitlength > 24 {
+ return nil
+ }
+
+ bits := make([]*Bit, bitlength)
+ for i := range bitlength {
+ bits[i] = NewBit(id, (val>>(bitlength-i-1))&1 != 0)
+ }
+ return bits
+}
+
func (b *Bit) IsSet() bool {
return b.set
}
diff --git a/nizk/stage1.go b/nizk/stage1.go
index e7ed44d..f58a4ac 100644
--- a/nizk/stage1.go
+++ b/nizk/stage1.go
@@ -98,6 +98,7 @@ func (b *Bit) reveal(prev_true bool, Xs ...*Point) (r *StageReveal) {
s.Sent = true
} else {
r.Z = Y.Exp(s.x)
+ s.Sent = false
}
return r
diff --git a/nizk/stage2_test.go b/nizk/stage2_test.go
index 4e6232e..f147bd6 100644
--- a/nizk/stage2_test.go
+++ b/nizk/stage2_test.go
@@ -1,6 +1,7 @@
package nizk
import (
+ "slices"
"testing"
. "kesim.org/seal/common"
@@ -49,32 +50,30 @@ func TestStage2Simple1(t *testing.T) {
}
}
-func uint2bits(bid int) []*Bit {
+func int2bits(bid int, bitlength int) []*Bit {
id := Curve.RandomScalar()
- return []*Bit{
- NewBit(id, (bid>>4)&1 != 0),
- NewBit(id, (bid>>3)&1 != 0),
- NewBit(id, (bid>>2)&1 != 0),
- NewBit(id, (bid>>1)&1 != 0),
- NewBit(id, (bid>>0)&1 != 0),
+ bits := make([]*Bit, bitlength)
+ for i := range bitlength {
+ bits[i] = NewBit(id, (bid>>bitlength-i)&1 != 0)
}
+ return bits
}
-var Id = Curve.Identity()
-
func TestStage2Complex(t *testing.T) {
bid1 := 0b0101
bid2 := 0b0010
t.Logf("testing bid1: %05b vs. bid2: %05b", bid1, bid2)
- bits1 := uint2bits(bid1)
- bits2 := uint2bits(bid2)
+ bitlength := 4
+
+ bits1 := int2bits(bid1, bitlength)
+ bits2 := int2bits(bid2, bitlength)
lost1 := false
lost2 := false
- if len(bits1) != len(bits2) || len(bits1) != 5 {
+ if len(bits1) != len(bits2) || len(bits1) != bitlength {
t.Fatalf("oops")
}
@@ -117,7 +116,7 @@ func TestStage2Complex(t *testing.T) {
t.Logf("setting lost2 to true")
lost2 = true
}
- result |= 1 << (4 - c)
+ result |= 1 << (bitlength - 1 - c)
} else {
t.Logf("Z[%d] == Id, staying in stage1", c)
}
@@ -151,7 +150,7 @@ func TestStage2Complex(t *testing.T) {
t.Logf("setting lost2 to true")
lost2 = true
}
- result |= 1 << (4 - c)
+ result |= 1 << (bitlength - 1 - c)
}
}
}
@@ -161,9 +160,10 @@ func TestStage2Complex(t *testing.T) {
}
func TestFromPaper(t *testing.T) {
+ bitlength := 5
vals := []int{
- 0b01010,
0b01001,
+ 0b01010,
0b00111,
}
@@ -171,7 +171,7 @@ func TestFromPaper(t *testing.T) {
var bits = [3][]*Bit{}
for i, b := range vals {
- bits[i] = uint2bits(b)
+ bits[i] = int2bits(b, bitlength)
}
var lost = [3]bool{}
@@ -179,7 +179,7 @@ func TestFromPaper(t *testing.T) {
junction := -1
result := 0
- for idx := 0; idx < 5; idx++ {
+ for idx := 0; idx < bitlength; idx++ {
var c = [3]*StageCommitment{}
var r = [3]*StageReveal{}
@@ -212,7 +212,7 @@ func TestFromPaper(t *testing.T) {
t.Logf("bit %d, set lost[%d] to true, so far: %v", idx, i, lost)
}
}
- result |= 1 << (4 - idx)
+ result |= 1 << (bitlength - 1 - idx)
} else {
t.Logf("Z[%d] == Id, staying in stage1", idx)
}
@@ -249,12 +249,13 @@ func TestFromPaper(t *testing.T) {
t.Logf("bits[%d][%d], set lost[%d] to true, so far: %v", i, idx, i, lost)
}
}
- result |= 1 << (4 - idx)
+ result |= 1 << (bitlength - 1 - idx)
}
}
}
- if result != vals[0] {
- t.Fatalf("wrong result: %05b", result)
+ max := slices.Max(vals)
+ if result != max {
+ t.Fatalf("wrong result: %05b, expected: %05b", result, max)
}
}