Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -735,25 +735,31 @@ def denote_data_flow_for_step(
)
)
input_dimensionality_offsets = manifest.get_input_dimensionality_offsets()
print("input_dimensionality_offsets", input_dimensionality_offsets)
verify_step_input_dimensionality_offsets(
step_name=step_name,
input_dimensionality_offsets=input_dimensionality_offsets,
)
print("scalar_parameters_to_be_batched", scalar_parameters_to_be_batched)
inputs_dimensionalities = get_inputs_dimensionalities(
step_name=step_name,
step_type=manifest.type,
input_data=input_data,
scalar_parameters_to_be_batched=scalar_parameters_to_be_batched,
input_dimensionality_offsets=input_dimensionality_offsets,
)
print("inputs_dimensionalities", inputs_dimensionalities)
logger.debug(
f"For step: {node}, detected the following input dimensionalities: {inputs_dimensionalities}"
)
parameters_with_batch_inputs = grab_parameters_defining_batch_inputs(
inputs_dimensionalities=inputs_dimensionalities,
)
print("parameters_with_batch_inputs", parameters_with_batch_inputs)
dimensionality_reference_property = manifest.get_dimensionality_reference_property()
print("dimensionality_reference_property", dimensionality_reference_property)
output_dimensionality_offset = manifest.get_output_dimensionality_offset()
print("output_dimensionality_offset", output_dimensionality_offset)
verify_step_input_dimensionality_offsets(
step_name=step_name,
input_dimensionality_offsets=input_dimensionality_offsets,
Expand Down Expand Up @@ -812,6 +818,8 @@ def denote_data_flow_for_step(
scalar_parameters_to_be_batched=scalar_parameters_to_be_batched,
)
step_node_data.auto_batch_casting_lineage_supports = lineage_supports
print("lineage_supports", lineage_supports)
print("Data lineage of block output", data_lineage)
if data_lineage:
on_top_level_lineage_denoted(data_lineage[0])
step_node_data.data_lineage = data_lineage
Expand Down Expand Up @@ -1563,10 +1571,10 @@ def retrieve_batch_compatibility_of_input_selectors(
) -> Dict[str, Set[bool]]:
batch_compatibility_of_properties = defaultdict(set)
for parsed_selector in input_selectors:
property_name = parsed_selector.definition.property_name
target_set = batch_compatibility_of_properties[property_name]
for reference in parsed_selector.definition.allowed_references:
batch_compatibility_of_properties[
parsed_selector.definition.property_name
].update(reference.points_to_batch)
target_set |= reference.points_to_batch
return batch_compatibility_of_properties


Expand Down Expand Up @@ -1606,6 +1614,9 @@ def verify_declared_batch_compatibility_against_actual_inputs(
)
if batch_compatibility == {True} and False in actual_input_is_batch:
scalar_parameters_to_be_batched.add(property_name)
print(
f"property_name: {property_name}, batch_compatibility={batch_compatibility}, actual_input_is_batch={actual_input_is_batch}, step_accepts_batch_input={step_accepts_batch_input}"
)
return scalar_parameters_to_be_batched


Expand Down Expand Up @@ -1654,6 +1665,7 @@ def get_lineage_support_for_auto_batch_casted_parameters(
casted_dimensionality=parameter_dimensionality,
lineage_support=lineage_support,
)
print("DUMMY", result)
return result


Expand Down