diff --git a/modules/sdk-core/src/bitgo/utils/tss/ecdsa/ecdsaMPCv2.ts b/modules/sdk-core/src/bitgo/utils/tss/ecdsa/ecdsaMPCv2.ts index 49980335e3..78619bd805 100644 --- a/modules/sdk-core/src/bitgo/utils/tss/ecdsa/ecdsaMPCv2.ts +++ b/modules/sdk-core/src/bitgo/utils/tss/ecdsa/ecdsaMPCv2.ts @@ -1120,7 +1120,12 @@ export class EcdsaMPCv2Utils extends BaseEcdsaUtils { const userSignerBroadcastMsg1 = await userSigner.init(); const signatureShareRound1 = await getSignatureShareRoundOne(userSignerBroadcastMsg1, userGpgKey); const session = userSigner.getSession(); - const encryptedRound1Session = this.bitgo.encrypt({ input: session, password: walletPassphrase, adata }); + const sessionAdata = `${adata}:${DklsTypes.DsgState[DklsTypes.DsgState.Round1]}`; + const encryptedRound1Session = this.bitgo.encrypt({ + input: session, + password: walletPassphrase, + adata: sessionAdata, + }); const userGpgPubKey = userGpgKey.publicKey; const encryptedUserGpgPrvKey = this.bitgo.encrypt({ @@ -1170,10 +1175,11 @@ export class EcdsaMPCv2Utils extends BaseEcdsaUtils { const round1Session = this.bitgo.decrypt({ input: encryptedRound1Session, password: walletPassphrase }); - this.validateAdata(adata, encryptedRound1Session); + const round1SessionAdata = `${adata}:${DklsTypes.DsgState[DklsTypes.DsgState.Round1]}`; + this.validateAdata(round1SessionAdata, encryptedRound1Session); const userKeyShare = Buffer.from(prv, 'base64'); const userSigner = new DklsDsg.Dsg(userKeyShare, 0, derivationPath, hashBuffer); - await userSigner.setSession(round1Session); + await userSigner.setSession(round1Session, DklsTypes.DsgState.Round1); const deserializedMessages = DklsTypes.deserializeMessages(serializedBitGoToUserMessagesRound1); const userToBitGoMessagesRound2 = userSigner.handleIncomingMessages({ @@ -1191,7 +1197,13 @@ export class EcdsaMPCv2Utils extends BaseEcdsaUtils { bitgoGpgKey ); const session = userSigner.getSession(); - const encryptedRound2Session = this.bitgo.encrypt({ input: session, password: walletPassphrase, adata }); + // After two rounds of handleIncomingMessages, the session is in DsgState.Round3 (WaitMsg3). + const round3SessionAdata = `${adata}:${DklsTypes.DsgState[DklsTypes.DsgState.Round3]}`; + const encryptedRound2Session = this.bitgo.encrypt({ + input: session, + password: walletPassphrase, + adata: round3SessionAdata, + }); return { signatureShareRound2, @@ -1242,10 +1254,11 @@ export class EcdsaMPCv2Utils extends BaseEcdsaUtils { }); const round2Session = this.bitgo.decrypt({ input: encryptedRound2Session, password: walletPassphrase }); - this.validateAdata(adata, encryptedRound2Session); + const round3SessionAdata = `${adata}:${DklsTypes.DsgState[DklsTypes.DsgState.Round3]}`; + this.validateAdata(round3SessionAdata, encryptedRound2Session); const userKeyShare = Buffer.from(prv, 'base64'); const userSigner = new DklsDsg.Dsg(userKeyShare, 0, derivationPath, hashBuffer); - await userSigner.setSession(round2Session); + await userSigner.setSession(round2Session, DklsTypes.DsgState.Round3); const userToBitGoMessagesRound4 = userSigner.handleIncomingMessages({ p2pMessages: deserializedBitGoToUserMessagesRound3.p2pMessages, diff --git a/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts b/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts index bbcb36751b..81d51b10bb 100644 --- a/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts +++ b/modules/sdk-lib-mpc/src/tss/ecdsa-dkls/dsg.ts @@ -88,8 +88,9 @@ export class Dsg { /** * Sets the DSG session from a base64 string. * @param {string} session - base64 string of the DSG session + * @param {DsgState} [expectedRound] - if provided, the session's round must match this value */ - async setSession(session: string): Promise { + async setSession(session: string, expectedRound?: DsgState): Promise { this.dsgSession = undefined; if (!this.dklsWasm) { await this.loadDklsWasm(); @@ -112,6 +113,9 @@ export class Dsg { default: throw Error(`Invalid State: ${round}`); } + if (expectedRound !== undefined && this.dsgState !== expectedRound) { + throw Error(`Session round mismatch: expected ${DsgState[expectedRound]}, got ${DsgState[this.dsgState]}`); + } this.dsgSessionBytes = sessionBytes; } diff --git a/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts b/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts index a0c6e6e35c..8fc3d157aa 100644 --- a/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts +++ b/modules/sdk-lib-mpc/test/unit/tss/ecdsa/dklsDsg.ts @@ -12,6 +12,7 @@ import { DeserializedBroadcastMessage, DeserializedDklsSignature, DeserializedMessages, + DsgState, getDecodedReducedKeyShare, ReducedKeyShare, RetrofitData, @@ -409,4 +410,46 @@ describe('DKLS Dsg 2x3', function () { should.exist(convertedSignature); convertedSignature.split(':').length.should.equal(4); }); + + it('should succeed when setSession is called with correct expectedRound', async function () { + const vector = vectors[0]; + const party1 = new DklsDsg.Dsg( + fs.readFileSync(shareFiles[vector.party1]), + vector.party1, + vector.derivationPath, + crypto.createHash('sha256').update(Buffer.from(vector.msgToSign, 'hex')).digest() + ); + await party1.init(); + const round1Session = party1.getSession(); + + const party1Restored = new DklsDsg.Dsg( + fs.readFileSync(shareFiles[vector.party1]), + vector.party1, + vector.derivationPath, + crypto.createHash('sha256').update(Buffer.from(vector.msgToSign, 'hex')).digest() + ); + await party1Restored.setSession(round1Session, DsgState.Round1); + }); + + it('should throw when setSession is called with wrong expectedRound', async function () { + const vector = vectors[0]; + const party1 = new DklsDsg.Dsg( + fs.readFileSync(shareFiles[vector.party1]), + vector.party1, + vector.derivationPath, + crypto.createHash('sha256').update(Buffer.from(vector.msgToSign, 'hex')).digest() + ); + await party1.init(); + const round1Session = party1.getSession(); + + const party1Restored = new DklsDsg.Dsg( + fs.readFileSync(shareFiles[vector.party1]), + vector.party1, + vector.derivationPath, + crypto.createHash('sha256').update(Buffer.from(vector.msgToSign, 'hex')).digest() + ); + await party1Restored + .setSession(round1Session, DsgState.Round2) + .should.be.rejectedWith('Session round mismatch: expected Round2, got Round1'); + }); });