From a0b5aba71c6e1a9fc5fba83fea0cfb1f9fe0b9f2 Mon Sep 17 00:00:00 2001 From: Florian Dold Date: Sat, 14 Dec 2019 19:09:01 +0100 Subject: [PATCH] allow specifying common base type for unions --- src/util/codec.ts | 48 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/src/util/codec.ts b/src/util/codec.ts index 7dcf493a2..0215ce797 100644 --- a/src/util/codec.ts +++ b/src/util/codec.ts @@ -118,10 +118,10 @@ class ObjectCodecBuilder { } } -class UnionCodecBuilder { +class UnionCodecBuilder { private alternatives = new Map(); - constructor(private discriminator: D) {} + constructor(private discriminator: D, private baseCodec?: Codec) {} /** * Define a property for the object. @@ -129,7 +129,7 @@ class UnionCodecBuilder { alternative( tagValue: T[D], codec: Codec, - ): UnionCodecBuilder { + ): UnionCodecBuilder { this.alternatives.set(tagValue, { codec, tagValue }); return this as any; } @@ -140,21 +140,36 @@ class UnionCodecBuilder { * @param objectDisplayName name of the object that this codec operates on, * used in error messages. */ - build(objectDisplayName: string): Codec { + build(objectDisplayName: string): Codec { const alternatives = this.alternatives; const discriminator = this.discriminator; + const baseCodec = this.baseCodec; return { decode(x: any, c?: Context): R { const d = x[discriminator]; if (d === undefined) { - throw new DecodingError(`expected tag for ${objectDisplayName} at ${renderContext(c)}.${discriminator}`); + throw new DecodingError( + `expected tag for ${objectDisplayName} at ${renderContext( + c, + )}.${discriminator}`, + ); } const alt = alternatives.get(d); if (!alt) { - throw new DecodingError(`unknown tag for ${objectDisplayName} ${d} at ${renderContext(c)}.${discriminator}`); + throw new DecodingError( + `unknown tag for ${objectDisplayName} ${d} at ${renderContext( + c, + )}.${discriminator}`, + ); } - return alt.codec.decode(x); - } + const altDecoded = alt.codec.decode(x); + if (baseCodec) { + const baseDecoded = baseCodec.decode(x, c); + return { ...baseDecoded, ...altDecoded }; + } else { + return altDecoded; + } + }, }; } } @@ -180,10 +195,12 @@ export function stringConstCodec(s: V): Codec { if (x === s) { return x; } - throw new DecodingError(`expected string constant "${s}" at ${renderContext(c)}`); - } - } -}; + throw new DecodingError( + `expected string constant "${s}" at ${renderContext(c)}`, + ); + }, + }; +} /** * Return a codec for a value that must be a number. @@ -234,8 +251,11 @@ export function mapCodec(innerCodec: Codec): Codec<{ [x: string]: T }> { } export class UnionCodecPreBuilder { - discriminateOn(discriminator: D): UnionCodecBuilder { - return new UnionCodecBuilder(discriminator); + discriminateOn( + discriminator: D, + baseCodec?: Codec, + ): UnionCodecBuilder { + return new UnionCodecBuilder(discriminator, baseCodec); } }