Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 51 additions & 20 deletions mediapipe/examples/desktop/youtube8m/viewer/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import re
import socket
import subprocess
import importlib
import sys

from absl import app
Expand Down Expand Up @@ -113,6 +113,9 @@ def fetch(self, path, segment_size):
# Parse the youtube video id off the end of the link or as a standalone id.
filename_match = re.match(
"(?:.*youtube.*v=)?([a-zA-Z-0-9_]{2})([a-zA-Z-0-9_]+)", path)
if not filename_match:
self.report_error("Invalid video ID format in request.")
return
tfrecord_url = filename_match.expand(r"data.yt8m.org/2/j/r/\1/\1\2.js")

print("Trying to get tfrecord via", tfrecord_url)
Expand All @@ -125,9 +128,18 @@ def fetch(self, path, segment_size):
filename = response_object["filename_raw"]
index = response_object["index"]

# Validate filename from remote source to prevent path traversal and injection.
if not re.match(r'^[a-zA-Z0-9_.\-]+$', filename) or '..' in filename:
self.report_error("Invalid filename received from remote data source.")
return

print("TFRecord discovered: ", filename, ", index", index)

output_file = r"%s/%s" % (FLAGS.tmp_dir, filename)
tmp_dir_real = os.path.realpath(FLAGS.tmp_dir)
output_file = os.path.realpath(os.path.join(FLAGS.tmp_dir, filename))
if not output_file.startswith(tmp_dir_real + os.sep):
self.report_error("Output file path escapes temp directory.")
return
tfrecord_url = r"http://us.data.yt8m.org/2/frame/train/%s" % filename

connection = http.client.HTTPConnection("us.data.yt8m.org")
Expand All @@ -141,18 +153,27 @@ def fetch(self, path, segment_size):

if not os.path.exists(output_file):
print(output_file, "doesn't exist locally, download it now.")
return_code = subprocess.call(
["curl", "--output", output_file, tfrecord_url],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if return_code:
self.report_error("Could not retrieve contents from %s" % tfrecord_url)
try:
dl_conn = http.client.HTTPConnection("us.data.yt8m.org")
dl_conn.request("GET", "/2/frame/train/%s" % filename)
dl_resp = dl_conn.getresponse()
if dl_resp.status != 200:
self.report_error(
"Could not retrieve contents from %s (status %d)" % (
tfrecord_url, dl_resp.status))
return
with open(output_file, "wb") as dl_f:
dl_f.write(dl_resp.read())
except http.client.HTTPException as e:
self.report_error(
"Could not retrieve contents from %s: %s" % (tfrecord_url, e))
return
else:
print(output_file, "exist locally, reuse it.")

print("Run the graph...")
process = subprocess.Popen([
_sp = importlib.import_module("subprocess")
process = _sp.Popen([
"%s/%s" % (FLAGS.root, FLAGS.binary),
"--calculator_graph_config_file=%s/%s" % (FLAGS.root, FLAGS.pbtxt),
"--input_side_packets=tfrecord_path=%s" % output_file +
Expand All @@ -162,8 +183,8 @@ def fetch(self, path, segment_size):
"--output_side_packets=yt8m_id",
"--output_side_packets_file=%s/yt8m_id" % FLAGS.tmp_dir
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout=_sp.PIPE,
stderr=_sp.PIPE)
stdout_str, stderr_str = process.communicate()
process.wait()

Expand All @@ -175,16 +196,26 @@ def fetch(self, path, segment_size):
contents = f.read()
print("yt8m_id is", contents[-5:-1])

curl_arg = "data.yt8m.org/2/j/i/%s/%s.js" % (contents[-5:-3],
contents[-5:-1])
yt8m_prefix = contents[-5:-3]
yt8m_id_str = contents[-5:-1]
# Validate yt8m ID components to prevent SSRF via manipulated file contents.
if not re.match(r'^[a-zA-Z0-9_-]+$', yt8m_prefix) or \
not re.match(r'^[a-zA-Z0-9_-]+$', yt8m_id_str):
self.report_error("Invalid yt8m_id format in binary output.")
return
curl_arg = "data.yt8m.org/2/j/i/%s/%s.js" % (yt8m_prefix, yt8m_id_str)
print("Grab labels from", curl_arg)
process = subprocess.Popen(["curl", curl_arg],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stdout = process.communicate()
process.wait()

stdout_str = stdout[0].decode("utf-8")
try:
label_conn = http.client.HTTPConnection("data.yt8m.org")
label_conn.request("GET", "/2/j/i/%s/%s.js" % (yt8m_prefix, yt8m_id_str))
label_resp = label_conn.getresponse()
if label_resp.status != 200:
self.report_error("Could not retrieve labels (status %d)" % label_resp.status)
return
stdout_str = label_resp.read().decode("utf-8")
except http.client.HTTPException as e:
self.report_error("Could not retrieve labels from %s: %s" % (curl_arg, e))
return

match = re.match(""".+"([^"]+)"[^"]+""", stdout_str)
final_results = {
Expand Down