Skip to content

engine

django_spire.contrib.sync.database.engine

logger = logging.getLogger(__name__) module-attribute

BATCH_BYTES_DEFAULT = 2 * 1024 * 1024 module-attribute

CLOCK_DRIFT_MAX_DEFAULT = 300 module-attribute

DatabaseEngine

Source code in django_spire/contrib/sync/database/engine.py
def __init__(
    self,
    storage: DatabaseSyncStorage,
    graph: DependencyGraph,
    clock: HybridLogicalClock,
    node_id: str,
    *,
    batch_bytes: int | None = BATCH_BYTES_DEFAULT,
    batch_size: int | None = None,
    clock_drift_max: int | None = CLOCK_DRIFT_MAX_DEFAULT,
    foreign_key_columns: dict[str, list[tuple[str, str]]] | None = None,
    identity_field: str = 'id',
    lock: SyncLock | None = None,
    on_complete: Callable[[DatabaseResult], None] | None = None,
    on_phase: Callable[[SyncPhase], None] | None = None,
    payload_bytes_max: int | None = None,
    payload_records_max: int | None = None,
    peer_node_id: str | None = None,
    progress: Callable[[SyncStage, int, int], None] | None = None,
    reconciler: PayloadReconciler | None = None,
    transaction: Callable[[], AbstractContextManager[Any]] = nullcontext,
    transport: Transport | None = None,
) -> None:
    """Validate the engine configuration and store it on the instance.

    Args:
        storage: Sync storage backend (sequence allocator, checkpoints,
            persisted cursors).
        graph: Model dependency graph.
        clock: Hybrid logical clock for this node.
        node_id: Identifier of this node; must be non-empty.
        batch_bytes: Per-batch byte budget; ``None`` skips validation
            (presumably meaning "unlimited" — confirm downstream).
        batch_size: Per-batch record budget; ``None`` skips validation.
        clock_drift_max: Maximum tolerated clock drift; must be
            non-negative when given.
        foreign_key_columns: Optional mapping stored as-is (``{}`` when
            omitted).
        identity_field: Name of the identity field; must be non-empty.
        lock: Optional lock; held during ``process`` when set.
        on_complete: Optional completion callback.
        on_phase: Optional phase-change callback.
        payload_bytes_max: Optional payload byte cap; must be >= 1 when given.
        payload_records_max: Optional payload record cap; must be >= 1
            when given.
        peer_node_id: Peer identifier; required when ``transport`` is set
            and must differ from ``node_id``. Stored as ``''`` when omitted.
        progress: Optional progress callback.
        reconciler: Payload reconciler; a fresh ``PayloadReconciler()`` is
            created when omitted.
        transaction: Factory for the context manager that wraps writes;
            defaults to ``nullcontext`` (no transaction).
        transport: Optional transport; its presence selects client mode.

    Raises:
        InvalidParameterError: If any of the constraints above is violated.
    """
    # Identifiers that must be non-empty strings.
    for label, value in (
        ('node_id', node_id),
        ('identity_field', identity_field),
    ):
        if not value:
            message = f'{label} must be a non-empty string'
            raise InvalidParameterError(message)

    # Optional numeric limits: ``None`` disables the check, otherwise the
    # value must not fall below its floor. Checked in declaration order so
    # the first offending argument is the one reported.
    limits = (
        ('batch_bytes', batch_bytes, 1, '>= 1'),
        ('batch_size', batch_size, 1, '>= 1'),
        ('clock_drift_max', clock_drift_max, 0, 'non-negative'),
        ('payload_bytes_max', payload_bytes_max, 1, '>= 1'),
        ('payload_records_max', payload_records_max, 1, '>= 1'),
    )

    for label, value, floor, requirement in limits:
        if value is not None and value < floor:
            message = f'{label} must be {requirement} or None, got {value}'
            raise InvalidParameterError(message)

    # Client mode (a transport is configured) must know which peer it syncs
    # with, and that peer cannot be this node.
    if transport is not None and not peer_node_id:
        message = (
            'peer_node_id is required when transport is set '
            '(client mode)'
        )
        raise InvalidParameterError(message)

    if peer_node_id is not None and peer_node_id == node_id:
        message = (
            'peer_node_id must differ from node_id '
            f'(both are {node_id!r})'
        )
        raise InvalidParameterError(message)

    # Collaborators.
    self._storage = storage
    self._graph = graph
    self._clock = clock
    self._lock = lock
    self._transport = transport
    self._transaction = transaction
    self._reconciler = reconciler or PayloadReconciler()

    # Identity and schema hints.
    self._node_id = node_id
    self._peer_node_id = peer_node_id or ''
    self._identity_field = identity_field
    self._foreign_key_columns = foreign_key_columns or {}

    # Limits.
    self._batch_bytes = batch_bytes
    self._batch_size = batch_size
    self._clock_drift_max = clock_drift_max
    self._payload_bytes_max = payload_bytes_max
    self._payload_records_max = payload_records_max

    # Callbacks.
    self._on_complete = on_complete
    self._on_phase = on_phase
    self._progress = progress

    # Mutable per-run state: a set of keys per string label
    # (presumably model label — confirm against _reset_errored_keys).
    self._errored_keys: dict[str, set[str]] = defaultdict(set)

process

Source code in django_spire/contrib/sync/database/engine.py
def process(
    self,
    incoming: SyncManifest,
) -> tuple[SyncManifest, DatabaseResult]:
    """Handle one manifest received from a peer and build the reply (server side).

    The incoming manifest is validated (manifest, clock, per-model payloads)
    before any writes happen. Inside ``self._transaction()`` the configured
    lock, if any, is held globally and for the ``(node, peer)`` pair. Then
    the current sequence counter is snapshotted, the incoming payloads are
    applied via ``_apply_incoming``, and deferred backfill is flushed.

    The reply's ``local_sequence`` is the ``response_sequence_max`` reported
    by ``_apply_incoming`` when more batches remain (``has_more``).
    Otherwise it is the counter snapshot taken at the start.

    Args:
        incoming: Manifest received from the peer.

    Returns:
        ``(response, result)``: the checksummed manifest to send back and
        the ``DatabaseResult`` accumulated during this call.
    """
    result = DatabaseResult()

    self._reset_errored_keys()
    self._stamp_unstamped_records()

    # Reject malformed or clock-skewed input before touching the database.
    self._validate_manifest(incoming)
    self._validate_clock(incoming)

    accepted_payloads = self._validate_incoming_models(incoming, result)
    remote_node_id = incoming.node_id

    with self._transaction():
        lock = self._lock

        if lock:
            lock.hold_global()
            lock.hold(self._node_id, remote_node_id)

        # Snapshot taken under the lock; it bounds what this batch may send.
        sequence_snapshot = self._storage.get_sequence_allocator().current()

        outcome = self._apply_incoming(
            accepted_payloads,
            incoming.peer_sequence,
            remote_node_id,
            result,
            sequence_max=sequence_snapshot,
            records_max=self._batch_size,
            bytes_max=self._batch_bytes,
            after_keys=incoming.after_keys,
        )
        outgoing_payloads, more_pending, cursor_keys, batch_sequence_max = outcome

        self._flush_deferred_backfill()

    self._advance_clock(incoming)

    reported_sequence = (
        batch_sequence_max if more_pending else sequence_snapshot
    )

    response = SyncManifest(
        node_id=self._node_id,
        peer_sequence=incoming.local_sequence,
        local_sequence=reported_sequence,
        after_keys=cursor_keys,
        node_time=int(time.time()),
        payloads=outgoing_payloads,
        has_more=more_pending,
    )
    response.checksum = response.compute_checksum()

    self._finalize(result)

    return response, result

sync

Source code in django_spire/contrib/sync/database/engine.py
def sync(self, dry_run: bool = False) -> DatabaseResult:
    """Run client-side sync rounds with the configured peer.

    Each round works in four steps:

    1. Collect a batch of local changes starting from the checkpoint's
       ``local_sequence_pushed``.
    2. Send the batch via ``_exchange_and_validate``.
    3. Unless ``dry_run``, commit the checkpoint and the pagination cursors.
    4. Advance the clock from the response.

    Rounds repeat until neither side reports ``has_more`` and the last
    round carried no records or deletes in either direction.

    Args:
        dry_run: When true, run exactly one round and skip the commit
            phase. The exchange still happens and the clock is still
            advanced from the response.

    Returns:
        The ``DatabaseResult`` accumulated during the session.

    Raises:
        TransportRequiredError: If no transport is configured.
        InvalidParameterError: If no peer_node_id is configured.
    """
    if self._transport is None:
        message = (
            'Transport is required for sync(). '
            'Use process() for server-side.'
        )

        raise TransportRequiredError(message)

    if not self._peer_node_id:
        message = (
            'peer_node_id is required for sync(). '
            'Configure it on the engine.'
        )

        raise InvalidParameterError(message)

    result = DatabaseResult()

    self._reset_errored_keys()

    # Persisted cursors share one per-peer namespace:
    #   'server:*'  -> the peer's pagination cursors, echoed back to it as
    #                  manifest.after_keys;
    #   'collect:*' -> cursors for our own local collection.
    persisted = self._storage.get_after_keys(self._peer_node_id)

    server_cursors: dict[str, Any] = {
        key.removeprefix('server:'): value
        for key, value in persisted.items()
        if key.startswith('server:')
    }

    collect_cursors: dict[str, Any] = {
        key.removeprefix('collect:'): value
        for key, value in persisted.items()
        if key.startswith('collect:')
    }

    with self._managed_session(result) as session_id:
        self._stamp_unstamped_records()

        while True:
            self._enter_phase(
                SyncPhase.COLLECTING,
                session_id,
                SyncStage.VALIDATE,
            )

            checkpoint = self._storage.get_checkpoint(self._peer_node_id)
            peer_sequence = checkpoint.peer_sequence
            local_sequence_pushed = checkpoint.local_sequence_pushed

            counter_at_start = (
                self._storage.get_sequence_allocator().current()
            )

            manifest = self._collect(
                local_sequence_pushed,
                self._peer_node_id,
                sequence_max=counter_at_start,
                limit=self._batch_size,
                bytes_limit=self._batch_bytes,
                after_keys=collect_cursors,
            )

            # _collect's own cursors are kept locally; the outgoing manifest
            # instead carries the peer's cursors so it can resume paging.
            collected_cursors = manifest.after_keys or {}

            manifest.peer_sequence = peer_sequence
            manifest.after_keys = server_cursors
            manifest.checksum = manifest.compute_checksum()

            sent_max_sequence = manifest.local_sequence
            sent_has_more = manifest.has_more

            self._record_pushed(manifest, result)

            self._enter_phase(
                SyncPhase.EXCHANGING,
                session_id,
                SyncStage.CLASSIFY,
            )

            response = self._exchange_and_validate(manifest)

            self._enter_phase(
                SyncPhase.RECONCILING,
                session_id,
                SyncStage.MUTATE,
            )

            # Cursors only survive while the corresponding side still has
            # more to page through; otherwise the next round starts fresh.
            collect_cursors = collected_cursors if manifest.has_more else {}

            # Guard against a missing/None after_keys on the response, the
            # same way the collected manifest's after_keys is guarded above.
            if response.has_more:
                server_cursors = response.after_keys or {}
            else:
                server_cursors = {}

            if not dry_run:
                self._enter_phase(
                    SyncPhase.COMMITTING,
                    session_id,
                )

                self._commit(
                    peer_sequence,
                    local_sequence_pushed,
                    sent_max_sequence,
                    sent_has_more,
                    counter_at_start,
                    response,
                    result,
                    server_cursors=server_cursors,
                    collect_cursors=collect_cursors,
                )

            # Runs in dry-run mode too: the clock observes the peer's reply.
            self._advance_clock(response)

            if dry_run:
                break

            converged = not manifest.has_more and not response.has_more

            if converged:
                # A converged round that still moved data triggers one more
                # round; only an empty converged round ends the loop.
                exchanged = any(
                    payload.records or payload.deletes
                    for payload in manifest.payloads
                ) or any(
                    payload.records or payload.deletes
                    for payload in response.payloads
                )

                if not exchanged:
                    break

        self._enter_phase(SyncPhase.COMPLETE, session_id)
        self._finalize(result)

    return result