/* This file is part of libbrandt.
 * Copyright (C) 2016 GNUnet e.V.
 *
 * libbrandt is free software: you can redistribute it and/or modify it under
 * the terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 *
 * libbrandt is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * libbrandt.  If not, see <http://www.gnu.org/licenses/>.
 */

/**
 * @file fp_pub.c
 * @brief Implementation of the first price public outcome algorithm.
 * @author Markus Teich
 */

#include "platform.h"

#include <gcrypt.h>

#include "crypto.h"
#include "internals.h"
#include "util.h"


/**
 * fp_pub_prep_outcome prepares the outcome computation for first price
 * auctions with a public outcome.
 *
 * It allocates ad->gamma2, ad->delta2, ad->tmpa1 and ad->tmpb1 and fills them:
 * - gamma2[a][j] (and delta2[a][j]) get the same value for every bidder layer
 *   a. That value is, summed over all bidders i, the total of alpha[i] (resp.
 *   beta[i]) minus the j-th entry produced by smc_sums_partial. This is
 *   presumably the sum over all price indices above j; confirm against
 *   smc_sums_partial.
 * - tmpa1[j] (and tmpb1[j]) get \sum_{i=0}^{n-1} 2^i * alpha[i][j]
 *   (resp. beta), with i counted from zero.
 *
 * @param[in,out] ad Pointer to the BRANDT_Auction struct to operate on
 */
void
fp_pub_prep_outcome (struct BRANDT_Auction *ad)
{
	gcry_mpi_t       coeff = gcry_mpi_copy (GCRYMPI_CONST_ONE);
	gcry_mpi_point_t tmp = gcry_mpi_point_new (0);
	gcry_mpi_point_t *tlta1;
	gcry_mpi_point_t *tltb1;
	gcry_mpi_point_t **tlta2;
	gcry_mpi_point_t **tltb2;

	ad->gamma2 = smc_init2 (ad->n, ad->k);
	brandt_assert (ad->gamma2);

	ad->delta2 = smc_init2 (ad->n, ad->k);
	brandt_assert (ad->delta2);

	ad->tmpa1 = smc_init1 (ad->k);
	brandt_assert (ad->tmpa1);

	ad->tmpb1 = smc_init1 (ad->k);
	brandt_assert (ad->tmpb1);

	/* create temporary lookup tables with partial sums; check them like the
	 * persistent allocations above instead of dereferencing NULL later */
	tlta1 = smc_init1 (ad->k);
	brandt_assert (tlta1);
	tltb1 = smc_init1 (ad->k);
	brandt_assert (tltb1);
	tlta2 = smc_init2 (ad->n, ad->k);
	brandt_assert (tlta2);
	tltb2 = smc_init2 (ad->n, ad->k);
	brandt_assert (tltb2);

	/* temporary lookup table for sum of bid vectors */
	for (uint16_t i = 0; i < ad->n; i++)
	{
		smc_sums_partial (tlta2[i], ad->alpha[i], ad->k, 1, 1);
		smc_sums_partial (tltb2[i], ad->beta[i], ad->k, 1, 1);
		/* replace each partial sum by (total - partial sum). The total in
		 * index k-1 is only overwritten in the last iteration, so all
		 * earlier iterations still read the total. */
		for (uint16_t j = 0; j < ad->k; j++)
		{
			gcry_mpi_ec_sub (tlta2[i][j],
			                 tlta2[i][ad->k - 1],
			                 tlta2[i][j],
			                 ec_ctx);
			gcry_mpi_ec_sub (tltb2[i][j],
			                 tltb2[i][ad->k - 1],
			                 tltb2[i][j],
			                 ec_ctx);
		}
		brandt_assert (!ec_point_cmp (ec_zero, tlta2[i][ad->k - 1]));
		brandt_assert (!ec_point_cmp (ec_zero, tltb2[i][ad->k - 1]));
	}
	/* sum column j over all bidders (row stride is ad->k) */
	for (uint16_t j = 0; j < ad->k; j++)
	{
		smc_sum (tlta1[j], &tlta2[0][j], ad->n, ad->k);
		smc_sum (tltb1[j], &tltb2[0][j], ad->n, ad->k);
	}
	smc_free2 (tlta2, ad->n, ad->k);
	smc_free2 (tltb2, ad->n, ad->k);
	brandt_assert (!ec_point_cmp (ec_zero, tlta1[ad->k - 1]));
	brandt_assert (!ec_point_cmp (ec_zero, tltb1[ad->k - 1]));

	/* initialize tmp array with zeroes, since we are calculating a sum */
	for (uint16_t j = 0; j < ad->k; j++)
	{
		ec_point_copy (ad->tmpa1[j], ec_zero);
		ec_point_copy (ad->tmpb1[j], ec_zero);
	}
	/* store \sum_{i=1}^{n} 2^{i-1} b_i in ad->tmpa1/ad->tmpb1 until outcome
	 * determination, since it is needed each time a gamma,delta pair is
	 * received from another bidder */
	for (uint16_t i = 0; i < ad->n; i++)
	{
		for (uint16_t j = 0; j < ad->k; j++)
		{
			gcry_mpi_ec_mul (tmp, coeff, ad->alpha[i][j], ec_ctx);
			gcry_mpi_ec_add (ad->tmpa1[j], ad->tmpa1[j], tmp, ec_ctx);
			gcry_mpi_ec_mul (tmp, coeff, ad->beta[i][j], ec_ctx);
			gcry_mpi_ec_add (ad->tmpb1[j], ad->tmpb1[j], tmp, ec_ctx);
		}
		/* next bidder gets the next power of two */
		gcry_mpi_lshift (coeff, coeff, 1);
	}

	for (uint16_t j = 0; j < ad->k; j++)
	{
		/* copy unmasked outcome to all other bidder layers so they don't
		 * have to be recomputed to check the ZK proof_2dle's from other
		 * bidders when receiving their outcome messages */
		for (uint16_t a = 0; a < ad->n; a++)
		{
			ec_point_copy (ad->gamma2[a][j], tlta1[j]);
			ec_point_copy (ad->delta2[a][j], tltb1[j]);
		}
	}

	gcry_mpi_release (coeff);
	gcry_mpi_point_release (tmp);
	smc_free1 (tlta1, ad->k);
	smc_free1 (tltb1, ad->k);
}


/**
 * fp_pub_compute_outcome computes the outcome for first price auctions with a
 * public outcome and packs it into a message buffer together with proofs of
 * correctnes.
 *
 * @param[in] ad Pointer to the BRANDT_Auction struct to operate on
 * @param[out] buflen Size of the returned message buffer in bytes
 * @return A buffer containing the encrypted outcome vectors
 * which needs to be broadcast
 */
unsigned char *
fp_pub_compute_outcome (struct BRANDT_Auction *ad, size_t *buflen)
{
	unsigned char     *ret;
	unsigned char     *cur;
	gcry_mpi_point_t  tmpa = gcry_mpi_point_new (0);
	gcry_mpi_point_t  tmpb = gcry_mpi_point_new (0);
	struct msg_head   *head;
	struct ec_mpi     *gamma;
	struct ec_mpi     *delta;
	struct proof_2dle *proof2;

	brandt_assert (ad && buflen);

	*buflen = (sizeof (*head) +
	           ad->k * (sizeof (*gamma) +
	                    sizeof (*delta) +
	                    sizeof (*proof2)));
	ret = GNUNET_new_array (*buflen, unsigned char);

	head = (struct msg_head *)ret;
	head->prot_version = htonl (0);
	head->msg_type = htonl (msg_outcome);
	cur = ret + sizeof (*head);

	for (uint16_t j = 0; j < ad->k; j++)
	{
		gamma = (struct ec_mpi *)cur;
		delta = &((struct ec_mpi *)cur)[1];
		proof2 = (struct proof_2dle *)(cur + 2 * sizeof (struct ec_mpi));

		ec_point_copy (tmpa, ad->gamma2[ad->i][j]);
		ec_point_copy (tmpb, ad->delta2[ad->i][j]);

		/* apply random masking to first summand */
		smc_zkp_2dle (ad->gamma2[ad->i][j],
		              ad->delta2[ad->i][j],
		              tmpa,
		              tmpb,
		              NULL,
		              proof2);

		ec_point_serialize (gamma, ad->gamma2[ad->i][j]);
		ec_point_serialize (delta, ad->delta2[ad->i][j]);

		/* add winner determination for own gamma,delta */
		gcry_mpi_ec_add (ad->gamma2[ad->i][j],
		                 ad->gamma2[ad->i][j],
		                 ad->tmpa1[j],
		                 ec_ctx);
		gcry_mpi_ec_add (ad->delta2[ad->i][j],
		                 ad->delta2[ad->i][j],
		                 ad->tmpb1[j],
		                 ec_ctx);

		cur += sizeof (*gamma) + sizeof (*delta) + sizeof (*proof2);
	}

	gcry_mpi_point_release (tmpa);
	gcry_mpi_point_release (tmpb);
	return ret;
}


/**
 * fp_pub_recv_outcome parses and verifies another bidder's outcome share.
 *
 * For every price index j it does the following:
 * 1. Parse the received (gamma, delta) pair.
 * 2. Check its 2DLE proof against the values currently stored in
 *    ad->gamma2/ad->delta2 for the sender.
 * 3. On success, store the received pair and add the winner determination
 *    summand from ad->tmpa1/ad->tmpb1.
 *
 * @param[in,out] ad Pointer to the BRANDT_Auction struct to operate on
 * @param[in] buf Received payload. The expected size does not include a
 * struct msg_head, so the caller presumably strips it; verify.
 * @param[in] buflen Size of @a buf in bytes
 * @param[in] sender Index of the bidder who sent the message
 * @return 1 on success, 0 on a size mismatch or a failed proof check
 */
int
fp_pub_recv_outcome (struct BRANDT_Auction *ad,
                     const unsigned char   *buf,
                     size_t                buflen,
                     uint16_t              sender)
{
	const size_t     tuple_len = 2 * sizeof (struct ec_mpi) +
	                             sizeof (struct proof_2dle);
	gcry_mpi_point_t in_gamma = gcry_mpi_point_new (0);
	gcry_mpi_point_t in_delta = gcry_mpi_point_new (0);
	int              ok = 0;

	brandt_assert (ad && buf);

	if (buflen != ad->k * tuple_len)
	{
		weprintf ("wrong size of received outcome");
		goto cleanup;
	}

	for (uint16_t j = 0; j < ad->k; j++)
	{
		const unsigned char *pos = buf + (size_t)j * tuple_len;
		struct proof_2dle   *proof =
			(struct proof_2dle *)(pos + 2 * sizeof (struct ec_mpi));
		gcry_mpi_point_t    peer_gamma = ad->gamma2[sender][j];
		gcry_mpi_point_t    peer_delta = ad->delta2[sender][j];

		ec_point_parse (in_gamma, (struct ec_mpi *)pos);
		ec_point_parse (in_delta, &((struct ec_mpi *)pos)[1]);
		if (smc_zkp_2dle_check (in_gamma, in_delta, peer_gamma, peer_delta,
		                        proof))
		{
			weprintf ("wrong zkp2 for gamma, delta received");
			goto cleanup;
		}

		/* keep the verified masked pair and add the winner determination
		 * summand */
		ec_point_copy (peer_gamma, in_gamma);
		ec_point_copy (peer_delta, in_delta);
		gcry_mpi_ec_add (peer_gamma, peer_gamma, ad->tmpa1[j], ec_ctx);
		gcry_mpi_ec_add (peer_delta, peer_delta, ad->tmpb1[j], ec_ctx);
	}

	ok = 1;
cleanup:
	gcry_mpi_point_release (in_gamma);
	gcry_mpi_point_release (in_delta);
	return ok;
}


/**
 * fp_pub_prep_decryption allocates ad->phi2 and fills it with the still
 * encrypted outcome. For every price index j it sums the delta2 components
 * of all bidders and stores that sum in phi2[a][j] for every bidder layer a.
 *
 * @param[in,out] ad Pointer to the BRANDT_Auction struct to operate on
 */
void
fp_pub_prep_decryption (struct BRANDT_Auction *ad)
{
	gcry_mpi_point_t delta_sum = gcry_mpi_point_new (0);
	uint16_t         j = 0;

	ad->phi2 = smc_init2 (ad->n, ad->k);
	brandt_assert (ad->phi2);

	while (j < ad->k)
	{
		smc_sum (delta_sum, &ad->delta2[0][j], ad->n, ad->k);

		/* every layer gets the same value, so the ZK proofs in incoming
		 * outcome decryption messages can be checked without recomputing
		 * it */
		for (uint16_t bidder = 0; bidder < ad->n; bidder++)
			ec_point_copy (ad->phi2[bidder][j], delta_sum);
		j++;
	}

	gcry_mpi_point_release (delta_sum);
}


/**
 * fp_pub_decrypt_outcome decrypts part of the outcome and packs it into a
 * message buffer together with proofs of correctnes.
 *
 * @param[in] ad Pointer to the BRANDT_Auction struct to operate on
 * @param[out] buflen Size of the returned message buffer in bytes
 * @return A buffer containing the own share of the decrypted outcome
 * which needs to be broadcast
 */
unsigned char *
fp_pub_decrypt_outcome (struct BRANDT_Auction *ad, size_t *buflen)
{
	unsigned char     *ret;
	unsigned char     *cur;
	gcry_mpi_point_t  tmp = gcry_mpi_point_new (0);
	struct msg_head   *head;
	struct ec_mpi     *phi;
	struct proof_2dle *proof2;

	brandt_assert (ad && buflen);

	*buflen = (sizeof (*head) + ad->k * (sizeof (*phi) + sizeof (*proof2)));
	ret = GNUNET_new_array (*buflen, unsigned char);

	head = (struct msg_head *)ret;
	head->prot_version = htonl (0);
	head->msg_type = htonl (msg_decrypt);
	cur = ret + sizeof (*head);

	for (uint16_t j = 0; j < ad->k; j++)
	{
		phi = (struct ec_mpi *)cur;
		proof2 = (struct proof_2dle *)(cur + sizeof (*phi));

		ec_point_copy (tmp, ad->phi2[ad->i][j]);

		/* decrypt outcome component and prove the correct key was used */
		smc_zkp_2dle (ad->phi2[ad->i][j],
		              NULL,
		              tmp,
		              ec_gen,
		              ad->x,
		              proof2);

		ec_point_serialize (phi, ad->phi2[ad->i][j]);

		cur += sizeof (*phi) + sizeof (*proof2);
	}

	gcry_mpi_point_release (tmp);
	return ret;
}


/**
 * fp_pub_recv_decryption parses and verifies another bidder's share of the
 * outcome decryption.
 *
 * For every price index j it does the following:
 * 1. Parse the received phi value.
 * 2. Check its 2DLE proof against the sender's public key share
 *    ad->y[sender], the currently stored ad->phi2[sender][j] and ec_gen.
 * 3. On success, store the received value in ad->phi2[sender][j].
 *
 * @param[in,out] ad Pointer to the BRANDT_Auction struct to operate on
 * @param[in] buf Received payload. The expected size does not include a
 * struct msg_head, so the caller presumably strips it; verify.
 * @param[in] buflen Size of @a buf in bytes
 * @param[in] sender Index of the bidder who sent the message
 * @return 1 on success, 0 on a size mismatch or a failed proof check
 */
int
fp_pub_recv_decryption (struct BRANDT_Auction *ad,
                        const unsigned char   *buf,
                        size_t                buflen,
                        uint16_t              sender)
{
	const size_t     share_len = sizeof (struct ec_mpi) +
	                             sizeof (struct proof_2dle);
	gcry_mpi_point_t in_phi = gcry_mpi_point_new (0);
	int              ok = 0;

	brandt_assert (ad && buf);

	if (buflen != ad->k * share_len)
	{
		weprintf ("wrong size of received outcome decryption");
		goto cleanup;
	}

	for (uint16_t j = 0; j < ad->k; j++)
	{
		const unsigned char *pos = buf + (size_t)j * share_len;
		struct proof_2dle   *proof =
			(struct proof_2dle *)(pos + sizeof (struct ec_mpi));

		ec_point_parse (in_phi, (struct ec_mpi *)pos);
		if (smc_zkp_2dle_check (in_phi, ad->y[sender], ad->phi2[sender][j],
		                        ec_gen, proof))
		{
			weprintf ("wrong zkp2 for phi, y received");
			goto cleanup;
		}
		ec_point_copy (ad->phi2[sender][j], in_phi);
	}

	ok = 1;
cleanup:
	gcry_mpi_point_release (in_phi);
	return ok;
}


/**
 * fp_pub_determine_outcome determines the winner and the price of a first
 * price auction with public outcome from the decrypted outcome shares.
 *
 * The price is the highest index j for which the sum of all gamma2 components
 * minus the sum of all phi2 components is not the zero point. The discrete
 * log of that point, divided by ad->n, is scanned for its lowest set bit. That
 * bit's position is the winning bidder.
 *
 * @param[in] ad Pointer to the BRANDT_Auction struct to operate on
 * @param[out] len Set to 1 (the number of returned results) on success; may
 * be NULL
 * @return A newly allocated BRANDT_Result describing winner and price, or
 * NULL if no price or no winner could be determined
 */
struct BRANDT_Result *
fp_pub_determine_outcome (struct BRANDT_Auction *ad,
                          uint16_t              *len)
{
	struct BRANDT_Result *ret = NULL;
	int32_t              price = -1;
	int32_t              winner = -1;
	int                  dlogi = -1;
	gcry_mpi_point_t     sum_gamma = gcry_mpi_point_new (0);
	gcry_mpi_point_t     sum_phi = gcry_mpi_point_new (0);

	brandt_assert (ad);

	/* Scan the price indices from the highest down. The index must be
	 * signed: with an unsigned index "j >= 0" is always true, so j would
	 * wrap around and index gamma2/phi2 out of bounds whenever no component
	 * is non-zero (or ad->k is 0). */
	for (int32_t j = (int32_t)ad->k - 1; j >= 0; j--)
	{
		smc_sum (sum_gamma, &ad->gamma2[0][j], ad->n, ad->k);
		smc_sum (sum_phi, &ad->phi2[0][j], ad->n, ad->k);
		gcry_mpi_ec_sub (sum_gamma, sum_gamma, sum_phi, ec_ctx);
		/* first non-zero component determines the price */
		if (ec_point_cmp (sum_gamma, ec_zero))
		{
			price = j;
			break;
		}
	}

	/* no non-zero component: there is nothing to take the dlog of */
	if (-1 == price)
		goto quit;

	dlogi = GNUNET_CRYPTO_ecc_dlog (ec_dlogctx, sum_gamma);
	brandt_assert (dlogi > 0);

	/* all bidders participated with a multiplicative share */
	dlogi /= ad->n;

	/* can only support up to bits(dlogi) bidders */
	brandt_assert (sizeof (int) * 8 > ad->n);
	for (uint16_t i = 0; i < ad->n; i++)
	{
		/* first set bit determines the winner */
		if (dlogi & (1 << i))
		{
			winner = i;
			break;
		}
	}

	if (-1 == winner)
		goto quit;

	ret = GNUNET_new (struct BRANDT_Result);
	ret->bidder = winner;
	ret->price = price;
	ret->status = BRANDT_bidder_won;
	if (len)
		*len = 1;

quit:
	gcry_mpi_point_release (sum_gamma);
	gcry_mpi_point_release (sum_phi);
	return ret;
}