diff --git a/plots/json2csv.py b/plots/json2csv.py index 0de1306..25fc9fc 100644 --- a/plots/json2csv.py +++ b/plots/json2csv.py @@ -14,6 +14,7 @@ # val_loss param1 param2 ... import logging +import csv # Set PYTHONPATH=$PWD from plottools import * @@ -32,47 +33,37 @@ logging.info("Found %i directories." % len(rundirs)) # The CSV file -fp_out = open(output_file, "w") - -# Write CSV header -fp_out.write("val_loss,") -header = ",".join(selected) -fp_out.write(header) -fp_out.write("\n") - -def write_values(fp, val_loss, D, selected): - # I think we have to do this for consistent ordering - L = [ str(val_loss) ] + [ D[param] for param in selected ] - # print(L) - fp.write(",".join(L)) - fp.write("\n") - -for rundir in rundirs: - - Js = get_jsons(rundir) - if len(Js) == 0: - continue - - # Get parameters from the first JSON file - record_start = Js[0][0] - params = record_start["parameters"] - D = {} - for entry in params: - tokens = entry.split(":") - param = tokens[0] - if param in selected: - value = tokens[1].strip() - D[param] = value - - # Get minimum val_loss in the directory - val_losses = [] - for J in Js: - record_count = len(J) - record_penult = J[record_count-2] - val_losses.append(record_penult["validation_loss"]["set"]) - val_loss = min(val_losses) - - write_values(fp_out, val_loss, D, selected) - -fp_out.close() +with open(output_file, "w") as fp_out: + fieldnames = ["val_loss"] + selected + writer = csv.DictWriter(fp_out, fieldnames=fieldnames) + writer.writeheader() + + for rundir in rundirs: + + Js = get_jsons(rundir) + if len(Js) == 0: + continue + + # Get parameters from the first JSON file + record_start = Js[0][0] + params = record_start["parameters"] + D = {} + for entry in params: + tokens = entry.split(":") + param = tokens[0] + if param in selected: + # re-join tail e.g. ['data_url', 'ftp', '//ftp.mcs...'] + value = ":".join(tokens[1:]).strip() + D[param] = value + + # Get minimum val_loss in the directory + val_losses = [] + for J in Js: + record_penult = J[-2] + val_losses.append(record_penult["validation_loss"]["set"]) + val_loss = min(val_losses) + D["val_loss"] = val_loss + + writer.writerow(D) + logging.info("Wrote %s ." % output_file)