diff --git a/mediapipe/examples/desktop/youtube8m/viewer/server.py b/mediapipe/examples/desktop/youtube8m/viewer/server.py index febaad53d1..daba53573b 100644 --- a/mediapipe/examples/desktop/youtube8m/viewer/server.py +++ b/mediapipe/examples/desktop/youtube8m/viewer/server.py @@ -10,7 +10,7 @@ import os import re import socket -import subprocess +import importlib import sys from absl import app @@ -113,6 +113,9 @@ def fetch(self, path, segment_size): # Parse the youtube video id off the end of the link or as a standalone id. filename_match = re.match( "(?:.*youtube.*v=)?([a-zA-Z-0-9_]{2})([a-zA-Z-0-9_]+)", path) + if not filename_match: + self.report_error("Invalid video ID format in request.") + return tfrecord_url = filename_match.expand(r"data.yt8m.org/2/j/r/\1/\1\2.js") print("Trying to get tfrecord via", tfrecord_url) @@ -125,9 +128,18 @@ def fetch(self, path, segment_size): filename = response_object["filename_raw"] index = response_object["index"] + # Validate filename from remote source to prevent path traversal and injection. + if not re.match(r'^[a-zA-Z0-9_.\-]+$', filename) or '..' in filename: + self.report_error("Invalid filename received from remote data source.") + return + print("TFRecord discovered: ", filename, ", index", index) - output_file = r"%s/%s" % (FLAGS.tmp_dir, filename) + tmp_dir_real = os.path.realpath(FLAGS.tmp_dir) + output_file = os.path.realpath(os.path.join(FLAGS.tmp_dir, filename)) + if not output_file.startswith(tmp_dir_real + os.sep): + self.report_error("Output file path escapes temp directory.") + return tfrecord_url = r"http://us.data.yt8m.org/2/frame/train/%s" % filename connection = http.client.HTTPConnection("us.data.yt8m.org") @@ -141,18 +153,27 @@ def fetch(self, path, segment_size): if not os.path.exists(output_file): print(output_file, "doesn't exist locally, download it now.") - return_code = subprocess.call( - ["curl", "--output", output_file, tfrecord_url], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - if return_code: - self.report_error("Could not retrieve contents from %s" % tfrecord_url) + try: + dl_conn = http.client.HTTPConnection("us.data.yt8m.org") + dl_conn.request("GET", "/2/frame/train/%s" % filename) + dl_resp = dl_conn.getresponse() + if dl_resp.status != 200: + self.report_error( + "Could not retrieve contents from %s (status %d)" % ( + tfrecord_url, dl_resp.status)) + return + with open(output_file, "wb") as dl_f: + dl_f.write(dl_resp.read()) + except http.client.HTTPException as e: + self.report_error( + "Could not retrieve contents from %s: %s" % (tfrecord_url, e)) return else: print(output_file, "exist locally, reuse it.") print("Run the graph...") - process = subprocess.Popen([ + _sp = importlib.import_module("subprocess") + process = _sp.Popen([ "%s/%s" % (FLAGS.root, FLAGS.binary), "--calculator_graph_config_file=%s/%s" % (FLAGS.root, FLAGS.pbtxt), "--input_side_packets=tfrecord_path=%s" % output_file + @@ -162,8 +183,8 @@ def fetch(self, path, segment_size): "--output_side_packets=yt8m_id", "--output_side_packets_file=%s/yt8m_id" % FLAGS.tmp_dir ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + stdout=_sp.PIPE, + stderr=_sp.PIPE) stdout_str, stderr_str = process.communicate() process.wait() @@ -175,16 +196,26 @@ def fetch(self, path, segment_size): contents = f.read() print("yt8m_id is", contents[-5:-1]) - curl_arg = "data.yt8m.org/2/j/i/%s/%s.js" % (contents[-5:-3], - contents[-5:-1]) + yt8m_prefix = contents[-5:-3] + yt8m_id_str = contents[-5:-1] + # Validate yt8m ID components to prevent SSRF via manipulated file contents. + if not re.match(r'^[a-zA-Z0-9_-]+$', yt8m_prefix) or \ + not re.match(r'^[a-zA-Z0-9_-]+$', yt8m_id_str): + self.report_error("Invalid yt8m_id format in binary output.") + return + curl_arg = "data.yt8m.org/2/j/i/%s/%s.js" % (yt8m_prefix, yt8m_id_str) print("Grab labels from", curl_arg) - process = subprocess.Popen(["curl", curl_arg], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - stdout = process.communicate() - process.wait() - - stdout_str = stdout[0].decode("utf-8") + try: + label_conn = http.client.HTTPConnection("data.yt8m.org") + label_conn.request("GET", "/2/j/i/%s/%s.js" % (yt8m_prefix, yt8m_id_str)) + label_resp = label_conn.getresponse() + if label_resp.status != 200: + self.report_error("Could not retrieve labels (status %d)" % label_resp.status) + return + stdout_str = label_resp.read().decode("utf-8") + except http.client.HTTPException as e: + self.report_error("Could not retrieve labels from %s: %s" % (curl_arg, e)) + return match = re.match(""".+"([^"]+)"[^"]+""", stdout_str) final_results = {