Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/keyring-controller/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add `withController` action to run atomic operations on multiple keyrings (within a single transaction) ([#8416](https://github.com/MetaMask/core/pull/8416))
- This action takes a subset of the controllr (a `RestrictedController` object) that exposes `addNewKeyring` and `removeKeyring` methods to add and remove keyring during that transaction call.
- Expose `KeyringController:signTransaction` method through `KeyringController` messenger ([#8408](https://github.com/MetaMask/core/pull/8408))

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,24 @@ export type KeyringControllerWithKeyringAction = {
handler: KeyringController['withKeyring'];
};

/**
* Execute an operation against all keyrings as a mutually exclusive atomic
* operation. The operation receives a {@link RestrictedController} instance
* that exposes a read-only live view of all keyrings as well as
* `addNewKeyring` and `removeKeyring` methods to stage mutations.
*
* The method automatically persists changes at the end of the function
* execution, or rolls back the changes if an error is thrown.
*
* @param operation - Function to execute with the restricted controller.
* @returns Promise resolving to the result of the function execution.
* @template CallbackResult - The type of the value resolved by the callback function.
*/
export type KeyringControllerWithControllerAction = {
type: `KeyringController:withController`;
handler: KeyringController['withController'];
};

/**
* Select a keyring and execute the given operation with the selected
* keyring, **without** acquiring the controller's mutual exclusion lock.
Expand Down Expand Up @@ -334,4 +352,5 @@ export type KeyringControllerMethodActions =
| KeyringControllerPatchUserOperationAction
| KeyringControllerSignUserOperationAction
| KeyringControllerWithKeyringAction
| KeyringControllerWithControllerAction
| KeyringControllerWithKeyringUnsafeAction;
257 changes: 257 additions & 0 deletions packages/keyring-controller/src/KeyringController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3992,6 +3992,241 @@ describe('KeyringController', () => {
});
});

describe('withController', () => {
it('throws if the controller is locked', async () => {
await withController(
{ skipVaultCreation: true },
async ({ controller }) => {
await expect(controller.withController(jest.fn())).rejects.toThrow(
KeyringControllerErrorMessage.ControllerLocked,
);
},
);
});

it('provides the current keyrings to the callback', async () => {
await withController(async ({ controller, initialState }) => {
await controller.withController(async (restrictedController) => {
expect(restrictedController.keyrings).toHaveLength(1);
expect(restrictedController.keyrings[0].metadata).toStrictEqual(
initialState.keyrings[0].metadata,
);
});
});
});

it('returns the result of the callback', async () => {
await withController(async ({ controller }) => {
const result = await controller.withController(async () => 'hello');
expect(result).toBe('hello');
});
});

it('throws if the callback returns a raw keyring instance', async () => {
await withController(async ({ controller }) => {
await expect(
controller.withController(async (restrictedController) => {
return restrictedController.keyrings[0].keyring;
}),
).rejects.toThrow(
KeyringControllerErrorMessage.UnsafeDirectKeyringAccess,
);
});
});

describe('addNewKeyring', () => {
it('creates an initialized keyring and stages it for commit', async () => {
const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4';
stubKeyringClassWithAccount(MockKeyring, mockAddress);

await withController(
{ keyringBuilders: [keyringBuilderFactory(MockKeyring)] },
async ({ controller }) => {
await controller.withController(async (restrictedController) => {
const entry = await restrictedController.addNewKeyring(
MockKeyring.type,
);

expect(entry.keyring).toBeInstanceOf(MockKeyring);
expect(entry.metadata.id).toBeDefined();
});

expect(controller.state.keyrings).toHaveLength(2);
},
);
});

it('appears immediately in restrictedController.keyrings', async () => {
const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4';
stubKeyringClassWithAccount(MockKeyring, mockAddress);

await withController(
{ keyringBuilders: [keyringBuilderFactory(MockKeyring)] },
async ({ controller }) => {
await controller.withController(async (restrictedController) => {
expect(restrictedController.keyrings).toHaveLength(1);
await restrictedController.addNewKeyring(MockKeyring.type);
expect(restrictedController.keyrings).toHaveLength(2);
});
},
);
});

it('destroys created keyrings and does not commit them if the callback throws', async () => {
const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4';
stubKeyringClassWithAccount(MockKeyring, mockAddress);
const destroySpy = jest
.spyOn(MockKeyring.prototype, 'destroy')
.mockResolvedValue(undefined);

await withController(
{ keyringBuilders: [keyringBuilderFactory(MockKeyring)] },
async ({ controller }) => {
await expect(
controller.withController(async (restrictedController) => {
await restrictedController.addNewKeyring(MockKeyring.type);
throw new Error('Oops');
}),
).rejects.toThrow('Oops');

expect(destroySpy).toHaveBeenCalledTimes(1);
expect(controller.state.keyrings).toHaveLength(1);
},
);
});
});

describe('removeKeyring', () => {
it('removes a keyring by id and commits the removal', async () => {
const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4';
stubKeyringClassWithAccount(MockKeyring, mockAddress);

await withController(
{ keyringBuilders: [keyringBuilderFactory(MockKeyring)] },
async ({ controller }) => {
await controller.addNewKeyring(MockKeyring.type);
const idToRemove = controller.state.keyrings[1].metadata.id;

await controller.withController(async (restrictedController) => {
await restrictedController.removeKeyring(idToRemove);
});

expect(controller.state.keyrings).toHaveLength(1);
expect(
controller.state.keyrings.find(
(k) => k.metadata.id === idToRemove,
),
).toBeUndefined();
},
);
});

it('disappears from restrictedController.keyrings immediately', async () => {
const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4';
stubKeyringClassWithAccount(MockKeyring, mockAddress);

await withController(
{ keyringBuilders: [keyringBuilderFactory(MockKeyring)] },
async ({ controller }) => {
await controller.addNewKeyring(MockKeyring.type);
const idToRemove = controller.state.keyrings[1].metadata.id;

await controller.withController(async (restrictedController) => {
expect(restrictedController.keyrings).toHaveLength(2);
await restrictedController.removeKeyring(idToRemove);
expect(restrictedController.keyrings).toHaveLength(1);
});
},
);
});

it('destroys the removed keyring', async () => {
const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4';
stubKeyringClassWithAccount(MockKeyring, mockAddress);
const destroySpy = jest
.spyOn(MockKeyring.prototype, 'destroy')
.mockResolvedValue(undefined);

await withController(
{ keyringBuilders: [keyringBuilderFactory(MockKeyring)] },
async ({ controller }) => {
await controller.addNewKeyring(MockKeyring.type);
const idToRemove = controller.state.keyrings[1].metadata.id;

await controller.withController(async (restrictedController) => {
await restrictedController.removeKeyring(idToRemove);
});

expect(destroySpy).toHaveBeenCalledTimes(1);
},
);
});

it('throws KeyringNotFound for an unknown id', async () => {
await withController(async ({ controller }) => {
await expect(
controller.withController(async (restrictedController) => {
await restrictedController.removeKeyring('non-existent-id');
}),
).rejects.toThrow(KeyringControllerErrorMessage.KeyringNotFound);
});
});

it('destroys a keyring that was created then removed within the same callback', async () => {
const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4';
stubKeyringClassWithAccount(MockKeyring, mockAddress);
const destroySpy = jest
.spyOn(MockKeyring.prototype, 'destroy')
.mockResolvedValue(undefined);

await withController(
{ keyringBuilders: [keyringBuilderFactory(MockKeyring)] },
async ({ controller }) => {
await controller.withController(async (restrictedController) => {
const { metadata } = await restrictedController.addNewKeyring(
MockKeyring.type,
);
await restrictedController.removeKeyring(metadata.id);
});

expect(destroySpy).toHaveBeenCalledTimes(1);
expect(controller.state.keyrings).toHaveLength(1);
},
);
});
});

it('rolls back on error', async () => {
await withController(async ({ controller, initialState }) => {
await expect(
controller.withController(async (restrictedController) => {
await restrictedController.addNewKeyring(KeyringTypes.simple);
throw new Error('Oops');
}),
).rejects.toThrow('Oops');

expect(controller.state.keyrings).toHaveLength(
initialState.keyrings.length,
);
expect(await controller.getAccounts()).toStrictEqual(
initialState.keyrings[0].accounts,
);
});
});

it('does not update the vault if no keyrings change', async () => {
await withController(async ({ controller, encryptor }) => {
const encryptSpy = jest.spyOn(encryptor, 'encrypt');

await controller.withController(async () => {
// no-op
});

expect(encryptSpy).not.toHaveBeenCalled();
});
});
});

describe('withKeyringUnsafe', () => {
it('calls the given function without acquiring the lock', async () => {
await withController(async ({ controller }) => {
Expand Down Expand Up @@ -4505,6 +4740,28 @@ describe('KeyringController', () => {
});
});

describe('withController', () => {
it('should call withController', async () => {
await withController(async ({ messenger }) => {
const operation = jest.fn().mockResolvedValue('result');

const actionReturnValue = await messenger.call(
'KeyringController:withController',
operation,
);

expect(operation).toHaveBeenCalledWith(
expect.objectContaining({
keyrings: expect.any(Array),
addNewKeyring: expect.any(Function),
removeKeyring: expect.any(Function),
}),
);
expect(actionReturnValue).toBe('result');
});
});
});

describe('addNewKeyring', () => {
it('should call addNewKeyring', async () => {
const mockKeyringMetadata: KeyringMetadata = {
Expand Down
Loading
Loading