Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 86 additions & 17 deletions pymongo/_cmessagemodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@

#include "_cbsonmodule.h"
#include "buffer.h"
#include <limits.h>

static int
_check_int32_size(size_t size, const char *what) {
if (size > (size_t)INT32_MAX) {
PyErr_Format(PyExc_OverflowError,
"MongoDB %s exceeds maximum int32 size (%d bytes)",
what, INT32_MAX);
return 0;
}
return 1;
}

struct module_state {
PyObject* _cbson;
Expand Down Expand Up @@ -80,7 +92,8 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
PyObject* options_obj = NULL;
codec_options_t options;
buffer_t buffer = NULL;
int length_location, message_length;
int length_location;
size_t message_length;
PyObject* result = NULL;
struct module_state *state = GETSTATE(self);
if (!state) {
Expand Down Expand Up @@ -136,10 +149,18 @@ static PyObject* _cbson_query_message(PyObject* self, PyObject* args) {
max_size = (cur_size > max_size) ? cur_size : max_size;
}

message_length = pymongo_buffer_get_position(buffer) - length_location;
message_length =
(size_t)pymongo_buffer_get_position(buffer) -
(size_t)length_location;

if (!_check_int32_size(message_length, "message length")) {
goto fail;
}

buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);


/* objectify buffer */
result = Py_BuildValue("iy#i", request_id,
pymongo_buffer_get_buffer(buffer),
Expand All @@ -162,7 +183,8 @@ static PyObject* _cbson_get_more_message(PyObject* self, PyObject* args) {
int num_to_return;
long long cursor_id;
buffer_t buffer = NULL;
int length_location, message_length;
int length_location;
size_t message_length;
PyObject* result = NULL;

if (!PyArg_ParseTuple(args, "et#iL",
Expand Down Expand Up @@ -196,7 +218,14 @@ static PyObject* _cbson_get_more_message(PyObject* self, PyObject* args) {
goto fail;
}

message_length = pymongo_buffer_get_position(buffer) - length_location;
message_length =
(size_t)pymongo_buffer_get_position(buffer) -
(size_t)length_location;

if (!_check_int32_size(message_length, "getMore message length")) {
goto fail;
}

buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);

Expand Down Expand Up @@ -229,7 +258,8 @@ static PyObject* _cbson_op_msg(PyObject* self, PyObject* args) {
PyObject* options_obj = NULL;
codec_options_t options;
buffer_t buffer = NULL;
int length_location, message_length;
int length_location;
size_t message_length;
int total_size = 0;
int max_doc_size = 0;
PyObject* result = NULL;
Expand Down Expand Up @@ -279,7 +309,8 @@ static PyObject* _cbson_op_msg(PyObject* self, PyObject* args) {
}

if (identifier_length) {
int payload_one_length_location, payload_length;
int payload_one_length_location;
size_t payload_length;
/* Payload type 1 */
if (!buffer_write_bytes(buffer, "\x01", 1)) {
goto fail;
Expand Down Expand Up @@ -307,16 +338,32 @@ static PyObject* _cbson_op_msg(PyObject* self, PyObject* args) {
Py_CLEAR(doc);
}

payload_length = pymongo_buffer_get_position(buffer) - payload_one_length_location;
payload_length =
(size_t)pymongo_buffer_get_position(buffer) -
(size_t)payload_one_length_location;

if (!_check_int32_size(payload_length, "OP_MSG payload length")) {
goto fail;
}

buffer_write_int32_at_position(
buffer, payload_one_length_location, (int32_t)payload_length);
total_size += payload_length;
}

message_length = pymongo_buffer_get_position(buffer) - length_location;

message_length =
(size_t)pymongo_buffer_get_position(buffer) -
(size_t)length_location;

if (!_check_int32_size(message_length, "OP_MSG message length")) {
goto fail;
}

buffer_write_int32_at_position(
buffer, length_location, (int32_t)message_length);


/* objectify buffer */
result = Py_BuildValue("iy#ii", request_id,
pymongo_buffer_get_buffer(buffer),
Expand Down Expand Up @@ -365,8 +412,8 @@ _batched_op_msg(
long max_message_size;
int idx = 0;
int size_location;
int position;
int length;
size_t position;
size_t length;
PyObject* max_bson_size_obj = NULL;
PyObject* max_write_batch_size_obj = NULL;
PyObject* max_message_size_obj = NULL;
Expand Down Expand Up @@ -520,8 +567,13 @@ _batched_op_msg(
goto fail;
}

position = pymongo_buffer_get_position(buffer);
length = position - size_location;
position = (size_t)pymongo_buffer_get_position(buffer);
length = position - (size_t)size_location;

if (!_check_int32_size(length, "batched OP_MSG section length")) {
goto fail;
}

buffer_write_int32_at_position(buffer, size_location, (int32_t)length);
return 1;

Expand Down Expand Up @@ -591,7 +643,7 @@ _cbson_batched_op_msg(PyObject* self, PyObject* args) {
unsigned char op;
unsigned char ack;
int request_id;
int position;
size_t position;
PyObject* command = NULL;
PyObject* docs = NULL;
PyObject* ctx = NULL;
Expand Down Expand Up @@ -643,7 +695,12 @@ _cbson_batched_op_msg(PyObject* self, PyObject* args) {
}

request_id = rand();
position = pymongo_buffer_get_position(buffer);
position = (size_t)pymongo_buffer_get_position(buffer);

if (!_check_int32_size(position, "batched OP_MSG message length")) {
goto fail;
}

buffer_write_int32_at_position(buffer, 0, (int32_t)position);
buffer_write_int32_at_position(buffer, 4, (int32_t)request_id);
result = Py_BuildValue("iy#O", request_id,
Expand All @@ -657,6 +714,7 @@ _cbson_batched_op_msg(PyObject* self, PyObject* args) {
return result;
}


/* End OP_MSG -------------------------------------------- */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/* End OP_MSG -------------------------------------------- */
/* End OP_MSG -------------------------------------------- */


static int
Expand Down Expand Up @@ -850,10 +908,21 @@ _batched_write_command(
goto fail;
}

position = pymongo_buffer_get_position(buffer);
length = position - lst_len_loc - 1;
position = (size_t)pymongo_buffer_get_position(buffer);
length = position - (size_t)lst_len_loc - 1;

if (!_check_int32_size(length, "batched write list length")) {
goto fail;
}

buffer_write_int32_at_position(buffer, lst_len_loc, (int32_t)length);
length = position - cmd_len_loc;

length = position - (size_t)cmd_len_loc;

if (!_check_int32_size(length, "batched write command length")) {
goto fail;
}

buffer_write_int32_at_position(buffer, cmd_len_loc, (int32_t)length);
return 1;

Expand Down
Loading