-
Notifications
You must be signed in to change notification settings - Fork 147
refactor: unify duplicate DAG construction (dag.py + ExecutionGraph) #511
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e9b5b3d
eddf3a3
6b2a8c8
17e9288
b481244
2e48430
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,12 +9,15 @@ | |||||||||||||||||||||||||
| from typing import TYPE_CHECKING | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from data_designer.config.column_configs import GenerationStrategy | ||||||||||||||||||||||||||
| from data_designer.config.column_types import ColumnConfigT | ||||||||||||||||||||||||||
| from data_designer.engine.column_generators.utils.generator_classification import column_type_used_in_execution_dag | ||||||||||||||||||||||||||
| from data_designer.engine.dataset_builders.multi_column_configs import ( | ||||||||||||||||||||||||||
| DatasetBuilderColumnConfigT, | ||||||||||||||||||||||||||
| MultiColumnConfig, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError | ||||||||||||||||||||||||||
| from data_designer.engine.dataset_builders.utils.task_model import SliceRef | ||||||||||||||||||||||||||
| from data_designer.logging import LOG_INDENT | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -330,3 +333,94 @@ def to_mermaid(self) -> str: | |||||||||||||||||||||||||
| for dep in sorted(self._upstream.get(col, set())): | ||||||||||||||||||||||||||
| lines.append(f" {dep} --> {col}") | ||||||||||||||||||||||||||
| return "\n".join(lines) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _resolve_dag_column( | ||||||||||||||||||||||||||
| col_name: str, | ||||||||||||||||||||||||||
| dag_col_dict: dict[str, ColumnConfigT], | ||||||||||||||||||||||||||
| side_effect_map: dict[str, str], | ||||||||||||||||||||||||||
| ) -> str | None: | ||||||||||||||||||||||||||
| """Resolve a column name to its DAG producer. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns the column itself if it is a direct DAG column, the producing | ||||||||||||||||||||||||||
| column if it is a declared side-effect, or ``None`` if the name is not | ||||||||||||||||||||||||||
| known to this DAG (e.g. a seed or sampler column). | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| if col_name in dag_col_dict: | ||||||||||||||||||||||||||
| return col_name | ||||||||||||||||||||||||||
| return side_effect_map.get(col_name) | ||||||||||||||||||||||||||
|
Comment on lines
+349
to
+351
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method can be absorbed by |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _add_dag_edge( | ||||||||||||||||||||||||||
| name: str, | ||||||||||||||||||||||||||
| dep: str, | ||||||||||||||||||||||||||
| label: str, | ||||||||||||||||||||||||||
| dag_col_dict: dict[str, ColumnConfigT], | ||||||||||||||||||||||||||
| side_effect_map: dict[str, str], | ||||||||||||||||||||||||||
| upstream: dict[str, set[str]], | ||||||||||||||||||||||||||
| downstream: dict[str, set[str]], | ||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||
|
Comment on lines
+338
to
+362
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To follow our code style guide (public before private) these two private methods need to be pushed further down. |
||||||||||||||||||||||||||
| """Add a dependency edge from *dep*'s producer to *name* if the dep is a known DAG column. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Self-edges are skipped, consistent with ``ExecutionGraph.create``. | ||||||||||||||||||||||||||
| The *label* parameter (``"required"`` or ``"skip.when"``) is included in | ||||||||||||||||||||||||||
| the debug log so the source of each edge is visible during tracing. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| resolved = _resolve_dag_column(dep, dag_col_dict, side_effect_map) | ||||||||||||||||||||||||||
| if resolved is None or resolved == name: | ||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||
| logger.debug(f"{LOG_INDENT}🔗 `{name}` depends on `{resolved}` [{label}]") | ||||||||||||||||||||||||||
| upstream[name].add(resolved) | ||||||||||||||||||||||||||
| downstream[resolved].add(name) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]: | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This public method could use a docstring!
Suggested change
|
||||||||||||||||||||||||||
| non_dag_cols = [col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)] | ||||||||||||||||||||||||||
| dag_col_dict = {col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)} | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if not dag_col_dict: | ||||||||||||||||||||||||||
| return non_dag_cols | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # side_effect_col_name -> producing column name | ||||||||||||||||||||||||||
| side_effect_map: dict[str, str] = {} | ||||||||||||||||||||||||||
| for name, col in dag_col_dict.items(): | ||||||||||||||||||||||||||
| for se_col in col.side_effect_columns: | ||||||||||||||||||||||||||
| existing = side_effect_map.get(se_col) | ||||||||||||||||||||||||||
| if existing is not None and existing != name: | ||||||||||||||||||||||||||
| raise ConfigCompilationError( | ||||||||||||||||||||||||||
| f"Side-effect column {se_col!r} is already produced by {existing!r}; " | ||||||||||||||||||||||||||
| f"cannot register a second producer {name!r}. " | ||||||||||||||||||||||||||
| f"Use distinct side-effect column names for each pipeline stage." | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| side_effect_map[se_col] = name | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| upstream: dict[str, set[str]] = {name: set() for name in dag_col_dict} | ||||||||||||||||||||||||||
| downstream: dict[str, set[str]] = {name: set() for name in dag_col_dict} | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| logger.info("⛓️ Sorting column configs into a Directed Acyclic Graph") | ||||||||||||||||||||||||||
| for name, col in dag_col_dict.items(): | ||||||||||||||||||||||||||
| for req in col.required_columns: | ||||||||||||||||||||||||||
| _add_dag_edge(name, req, "required", dag_col_dict, side_effect_map, upstream, downstream) | ||||||||||||||||||||||||||
| if col.skip is not None: | ||||||||||||||||||||||||||
| for skip_col in col.skip.columns: | ||||||||||||||||||||||||||
| _add_dag_edge(name, skip_col, "skip.when", dag_col_dict, side_effect_map, upstream, downstream) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| in_degree = {name: len(ups) for name, ups in upstream.items()} | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This (Kahn's algorithm) is nearly identical to get_topological_order above (lines 236-258). Could we extract a shared _kahns_topological_sort(nodes, upstream, downstream) -> list[str] helper that both call? They're in the same file and the only difference is the error message, which we could unify. This will also be helpful in a future PR where we'll want to reuse it for the DAG inside samplers. |
||||||||||||||||||||||||||
| queue: deque[str] = deque(name for name, deg in in_degree.items() if deg == 0) | ||||||||||||||||||||||||||
| order: list[str] = [] | ||||||||||||||||||||||||||
| while queue: | ||||||||||||||||||||||||||
| name = queue.popleft() | ||||||||||||||||||||||||||
| order.append(name) | ||||||||||||||||||||||||||
| for child in downstream.get(name, set()): | ||||||||||||||||||||||||||
| in_degree[child] -= 1 | ||||||||||||||||||||||||||
| if in_degree[child] == 0: | ||||||||||||||||||||||||||
| queue.append(child) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if len(order) != len(dag_col_dict): | ||||||||||||||||||||||||||
| raise DAGCircularDependencyError( | ||||||||||||||||||||||||||
| "🛑 The Data Designer column configurations contain cyclic dependencies. Please " | ||||||||||||||||||||||||||
| "inspect the column configurations and ensure they can be sorted without " | ||||||||||||||||||||||||||
| "circular references." | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return non_dag_cols + [dag_col_dict[n] for n in order] | ||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.