aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorÖzgür Kesim <oec@codeblau.de>2024-11-14 21:54:14 +0100
committerÖzgür Kesim <oec@codeblau.de>2024-11-14 21:54:14 +0100
commit53b2c23ec4d2260c930d6403b04a6564c0a36245 (patch)
tree04132e1b220877aafd37a600ea92bd0f1d132106
parent38ab8f84c71ef2448fbbd10652fd4068cfbc3a31 (diff)
stage2: fix logic error for lost case
-rw-r--r--nizk/stage2.go2
-rw-r--r--nizk/stage2_test.go64
2 files changed, 36 insertions, 30 deletions
diff --git a/nizk/stage2.go b/nizk/stage2.go
index f38475d..8cda33e 100644
--- a/nizk/stage2.go
+++ b/nizk/stage2.go
@@ -128,7 +128,7 @@ func (b *Bit) RevealStage2(lost bool, prev *Bit, Xs ...*Point) (rv2 *StageReveal
ch := Challenge(points...)
pr = &Stage2Proof{}
- if !prev.IsSet() {
+ if lost {
pr.Ch[0] = ω[0]
pr.Ch[1] = ω[1]
pr.Ch[2] = ch.Sub(ω[0]).Sub(ω[1])
diff --git a/nizk/stage2_test.go b/nizk/stage2_test.go
index cfd6e13..2e5ac06 100644
--- a/nizk/stage2_test.go
+++ b/nizk/stage2_test.go
@@ -1,44 +1,50 @@
package nizk
import (
- "fmt"
"testing"
. "kesim.org/seal/common"
)
-func TestStage2Simple(t *testing.T) {
+func TestStage2Simple1(t *testing.T) {
id := Curve.RandomScalar()
- b1, _, _ := NewBit(id, false) // This is also the junction
- c1 := b1.StageCommit()
- r1, _ := b1.RevealStage1()
-
- // Because the first index is a junction, any subsequent
- // combination of Bits must verify with 'lost' set to true
- // in the RevealStage2 calls.
- for _, s := range [][2]bool{
- {false, false},
- {true, false},
- {false, true},
- {true, true},
- } {
- b2, bc2, _ := NewBit(id, s[0])
- b3, bc3, _ := NewBit(id, s[1])
-
- c2 := b2.StageCommit()
- c3 := b3.StageCommit()
- t.Run(fmt.Sprintf("variant %t %t b2.b1", s[0], s[1]), func(t *testing.T) {
- r2, p2 := b2.RevealStage2(true, b1) // We had lost previously
+
+ for _, lost := range []bool{true, false} {
+ b1, _, _ := NewBit(id, !lost)
+ c1 := b1.StageCommit()
+ r1, _ := b1.RevealStage1()
+
+ // Because the first index is a junction, any subsequent
+ // combination of Bits must verify with 'lost' set to true
+ // in the RevealStage2 calls.
+ for _, s := range [][2]bool{
+ {false, false},
+ {true, false},
+ {false, true},
+ {true, true},
+ } {
+ b2, bc2, _ := NewBit(id, s[0])
+ b3, bc3, _ := NewBit(id, s[1])
+ b4, bc4, _ := NewBit(id, s[1]) // same as b3
+
+ c2 := b2.StageCommit()
+ c3 := b3.StageCommit()
+ c4 := b4.StageCommit()
+
+ r2, p2 := b2.RevealStage2(lost, b1)
if !bc2.VerifyStage2(c1, c2, r1, r2, p2) {
- t.Fatalf("failed to verify bc2")
+ t.Fatalf("failed to verify b2: %t b3: %t bc2/b1", s[0], s[1])
}
- })
- t.Run(fmt.Sprintf("variant %t %t, b3.b1", s[0], s[1]), func(t *testing.T) {
- r3, p3 := b3.RevealStage2(true, b1) // We had lost previously
+
+ r3, p3 := b3.RevealStage2(lost, b1)
if !bc3.VerifyStage2(c1, c3, r1, r3, p3) {
- t.Fatalf("failed to verify bc3")
+ t.Fatalf("failed to verify b1: %t b3: %t bc3/b1", s[0], s[1])
}
- })
- }
+ r4, p4 := b4.RevealStage2(lost, b1)
+ if !bc4.VerifyStage2(c1, c4, r1, r4, p4) {
+ t.Fatalf("failed to verify b1: %t b4: %t bc4/b1", s[0], s[1])
+ }
+ }
+ }
}