From 53b2c23ec4d2260c930d6403b04a6564c0a36245 Mon Sep 17 00:00:00 2001 From: Özgür Kesim Date: Thu, 14 Nov 2024 21:54:14 +0100 Subject: stage2: fix logic error for lost case --- nizk/stage2.go | 2 +- nizk/stage2_test.go | 64 +++++++++++++++++++++++++++++------------------------ 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/nizk/stage2.go b/nizk/stage2.go index f38475d..8cda33e 100644 --- a/nizk/stage2.go +++ b/nizk/stage2.go @@ -128,7 +128,7 @@ func (b *Bit) RevealStage2(lost bool, prev *Bit, Xs ...*Point) (rv2 *StageReveal ch := Challenge(points...) pr = &Stage2Proof{} - if !prev.IsSet() { + if lost { pr.Ch[0] = ω[0] pr.Ch[1] = ω[1] pr.Ch[2] = ch.Sub(ω[0]).Sub(ω[1]) diff --git a/nizk/stage2_test.go b/nizk/stage2_test.go index cfd6e13..2e5ac06 100644 --- a/nizk/stage2_test.go +++ b/nizk/stage2_test.go @@ -1,44 +1,50 @@ package nizk import ( - "fmt" "testing" . "kesim.org/seal/common" ) -func TestStage2Simple(t *testing.T) { +func TestStage2Simple1(t *testing.T) { id := Curve.RandomScalar() - b1, _, _ := NewBit(id, false) // This is also the junction - c1 := b1.StageCommit() - r1, _ := b1.RevealStage1() - - // Because the first index is a junction, any subsequent - // combination of Bits must verify with 'lost' set to true - // in the RevealStage2 calls. - for _, s := range [][2]bool{ - {false, false}, - {true, false}, - {false, true}, - {true, true}, - } { - b2, bc2, _ := NewBit(id, s[0]) - b3, bc3, _ := NewBit(id, s[1]) - - c2 := b2.StageCommit() - c3 := b3.StageCommit() - t.Run(fmt.Sprintf("variant %t %t b2.b1", s[0], s[1]), func(t *testing.T) { - r2, p2 := b2.RevealStage2(true, b1) // We had lost previously + + for _, lost := range []bool{true, false} { + b1, _, _ := NewBit(id, !lost) + c1 := b1.StageCommit() + r1, _ := b1.RevealStage1() + + // Because the first index is a junction, any subsequent + // combination of Bits must verify with 'lost' set to true + // in the RevealStage2 calls. + for _, s := range [][2]bool{ + {false, false}, + {true, false}, + {false, true}, + {true, true}, + } { + b2, bc2, _ := NewBit(id, s[0]) + b3, bc3, _ := NewBit(id, s[1]) + b4, bc4, _ := NewBit(id, s[1]) // same as b3 + + c2 := b2.StageCommit() + c3 := b3.StageCommit() + c4 := b4.StageCommit() + + r2, p2 := b2.RevealStage2(lost, b1) if !bc2.VerifyStage2(c1, c2, r1, r2, p2) { - t.Fatalf("failed to verify bc2") + t.Fatalf("failed to verify b2: %t b3: %t bc2/b1", s[0], s[1]) } - }) - t.Run(fmt.Sprintf("variant %t %t, b3.b1", s[0], s[1]), func(t *testing.T) { - r3, p3 := b3.RevealStage2(true, b1) // We had lost previously + + r3, p3 := b3.RevealStage2(lost, b1) if !bc3.VerifyStage2(c1, c3, r1, r3, p3) { - t.Fatalf("failed to verify bc3") + t.Fatalf("failed to verify b1: %t b3: %t bc3/b1", s[0], s[1]) } - }) - } + r4, p4 := b4.RevealStage2(lost, b1) + if !bc4.VerifyStage2(c1, c4, r1, r4, p4) { + t.Fatalf("failed to verify b1: %t b4: %t bc4/b1", s[0], s[1]) + } + } + } } -- cgit v1.2.3