Skip to content

engine

django_spire.contrib.sync.database.engine

logger = logging.getLogger(__name__) module-attribute

BATCH_BYTES_DEFAULT = 16 * 1024 * 1024 module-attribute

CLOCK_DRIFT_MAX_DEFAULT = 300 module-attribute

DatabaseEngine

Source code in django_spire/contrib/sync/database/engine.py
def __init__(
    self,
    storage: DatabaseSyncStorage,
    graph: DependencyGraph,
    clock: HybridLogicalClock,
    node_id: str,
    *,
    batch_bytes: int | None = BATCH_BYTES_DEFAULT,
    batch_size: int | None = None,
    clock_drift_max: int | None = CLOCK_DRIFT_MAX_DEFAULT,
    identity_field: str = 'id',
    lock: SyncLock | None = None,
    on_complete: Callable[[DatabaseResult], None] | None = None,
    on_phase: Callable[[SyncPhase], None] | None = None,
    payload_bytes_max: int | None = None,
    payload_records_max: int | None = None,
    progress: Callable[[SyncStage, int, int], None] | None = None,
    reconciler: PayloadReconciler | None = None,
    transaction: Callable[[], AbstractContextManager[Any]] = nullcontext,
    transport: Transport | None = None,
) -> None:
    """Validate the engine configuration and store it on the instance.

    Every argument is kept on a same-named private attribute
    (``node_id`` -> ``self._node_id`` and so on). Validation runs
    before any attribute is assigned, so a rejected configuration
    leaves nothing set by this method.

    Args:
        storage: Stored unchanged as ``self._storage``.
        graph: Stored unchanged as ``self._graph``.
        clock: Stored unchanged as ``self._clock``.
        node_id: Must be truthy (non-empty).
        batch_bytes: ``None`` or a value ``>= 1``.
        batch_size: ``None`` or a value ``>= 1``.
        clock_drift_max: ``None`` or a value ``>= 0``.
        identity_field: Must be truthy (non-empty).
        lock: Stored unchanged; may be ``None``.
        on_complete: Stored unchanged; may be ``None``.
        on_phase: Stored unchanged; may be ``None``.
        payload_bytes_max: ``None`` or a value ``>= 1``.
        payload_records_max: ``None`` or a value ``>= 1``.
        progress: Stored unchanged; may be ``None``.
        reconciler: When falsy, a fresh ``PayloadReconciler()`` is
            created and used instead.
        transaction: Zero-argument callable returning a context
            manager; defaults to ``nullcontext``.
        transport: Stored unchanged; may be ``None``.

    Raises:
        InvalidParameterError: If ``node_id`` or ``identity_field`` is
            empty, or if any bounded numeric option is below its
            minimum.
    """
    if not node_id:
        raise InvalidParameterError('node_id must be a non-empty string')

    if not identity_field:
        raise InvalidParameterError(
            'identity_field must be a non-empty string'
        )

    # (option name, supplied value, smallest accepted value, wording used
    # in the error message). The tuple order is the order in which the
    # options are checked, so the first offending option is the one
    # reported.
    bounded_options = (
        ('batch_bytes', batch_bytes, 1, '>= 1'),
        ('batch_size', batch_size, 1, '>= 1'),
        ('clock_drift_max', clock_drift_max, 0, 'non-negative'),
        ('payload_bytes_max', payload_bytes_max, 1, '>= 1'),
        ('payload_records_max', payload_records_max, 1, '>= 1'),
    )

    for option, value, minimum, requirement in bounded_options:
        # None is always accepted for these options.
        if value is not None and value < minimum:
            raise InvalidParameterError(
                f'{option} must be {requirement} or None, got {value}'
            )

    # Required collaborators and identity.
    self._storage = storage
    self._graph = graph
    self._clock = clock
    self._node_id = node_id
    self._identity_field = identity_field

    # Numeric limits (already validated above).
    self._batch_bytes = batch_bytes
    self._batch_size = batch_size
    self._clock_drift_max = clock_drift_max
    self._payload_bytes_max = payload_bytes_max
    self._payload_records_max = payload_records_max

    # Optional collaborators and callbacks.
    self._lock = lock
    self._on_complete = on_complete
    self._on_phase = on_phase
    self._progress = progress
    self._reconciler = reconciler or PayloadReconciler()
    self._transaction = transaction
    self._transport = transport

process

Source code in django_spire/contrib/sync/database/engine.py
def process(
    self,
    incoming: SyncManifest,
) -> tuple[SyncManifest, DatabaseResult]:
    """Handle one incoming manifest and build the response manifest.

    The manifest and its clock are validated first, then its payloads
    are filtered through ``_validate_incoming_models``. Inside
    ``self._transaction()`` the configured lock (if any) is held for
    this node, the clock is read, and ``_apply_incoming`` is called with
    the configured batch limits.

    Args:
        incoming: The manifest received from the peer.

    Returns:
        A ``(response, result)`` tuple: the checksummed response
        manifest and the ``DatabaseResult`` populated while processing.
    """
    outcome = DatabaseResult()

    self._validate_manifest(incoming)
    self._validate_clock(incoming)

    accepted = self._validate_incoming_models(incoming, outcome)

    def _carries_changes(payloads) -> bool:
        # True when at least one payload has records or deletes.
        return any(p.records or p.deletes for p in payloads)

    with self._transaction():
        if self._lock:
            self._lock.hold(self._node_id)

        now = self._clock.now()

        outgoing, more_pending, next_keys = self._apply_incoming(
            accepted,
            incoming.checkpoint,
            outcome,
            received_at=now,
            records_max=self._batch_size,
            bytes_max=self._batch_bytes,
            after_keys=incoming.after_keys,
        )

    response_checkpoint = self._max_response_checkpoint(
        incoming.checkpoint,
        outgoing,
    )

    quiet = not (
        _carries_changes(accepted)
        or _carries_changes(outgoing)
    )

    # Only when nothing moved in either direction and no further batch is
    # pending is the checkpoint raised to at least the clock reading
    # taken inside the transaction.
    if not more_pending and quiet:
        response_checkpoint = max(response_checkpoint, now)

    reply = SyncManifest(
        node_id=self._node_id,
        checkpoint=response_checkpoint,
        after_keys=next_keys,
        node_time=int(time.time()),
        payloads=outgoing,
        has_more=more_pending,
    )

    # The checksum is computed once every other field is populated.
    reply.checksum = reply.compute_checksum()

    self._finalize(outcome)

    return reply, outcome

sync

Source code in django_spire/contrib/sync/database/engine.py
def sync(self, dry_run: bool = False) -> DatabaseResult:
    """Run sync rounds against the transport until nothing is left.

    Each round collects a local manifest, exchanges it for a response,
    advances both pagination cursors, and (unless ``dry_run``) commits.
    The loop ends after a round in which neither side reports
    ``has_more`` and neither manifest carried records or deletes.

    Args:
        dry_run: When true, the commit step is skipped and the loop
            stops after a single round.

    Returns:
        The ``DatabaseResult`` populated over all rounds.

    Raises:
        TransportRequiredError: If no transport was configured.
    """
    if self._transport is None:
        raise TransportRequiredError(
            'Transport is required for sync(). '
            'Use process() for server-side.'
        )

    result = DatabaseResult()

    persisted = self._storage.get_after_keys(self._node_id)

    def _cursors_under(prefix: str) -> dict:
        # Persisted cursors share one mapping; the key prefix tells the
        # two kinds apart and is stripped from the returned keys.
        return {
            key.removeprefix(prefix): value
            for key, value in persisted.items()
            if key.startswith(prefix)
        }

    def _carries_changes(payloads) -> bool:
        # True when at least one payload has records or deletes.
        return any(p.records or p.deletes for p in payloads)

    server_cursors: dict[str, Any] = _cursors_under('server:')
    collect_cursors: dict[str, Any] = _cursors_under('collect:')

    with self._managed_session(result) as session_id:
        while True:
            self._enter_phase(
                SyncPhase.COLLECTING,
                session_id,
                SyncStage.VALIDATE,
            )

            checkpoint = self._storage.get_checkpoint(self._node_id)

            manifest = self._collect(
                checkpoint,
                limit=self._batch_size,
                bytes_limit=self._batch_bytes,
                after_keys=collect_cursors,
            )

            # The outgoing manifest carries the server-side cursors, and
            # the checksum is taken after they are attached.
            manifest.after_keys = server_cursors
            manifest.checksum = manifest.compute_checksum()

            sent_snapshot = self._extract_record_snapshot(manifest)

            self._record_pushed(manifest, result)

            self._enter_phase(
                SyncPhase.EXCHANGING,
                session_id,
                SyncStage.CLASSIFY,
            )

            response = self._exchange_and_validate(manifest)

            received_snapshot = self._extract_record_snapshot(response)

            self._enter_phase(
                SyncPhase.RECONCILING,
                session_id,
                SyncStage.MUTATE,
            )

            # Local cursors restart empty every round and are refilled
            # only while the local side still has more to send.
            collect_cursors = {}

            if manifest.has_more:
                collect_cursors = {
                    payload.model_label: tail
                    for payload in manifest.payloads
                    if (tail := _last_cursor(payload.records))
                }

            server_cursors = (
                response.after_keys if response.has_more else {}
            )

            if dry_run:
                # A dry run never commits and performs a single round.
                break

            self._enter_phase(
                SyncPhase.COMMITTING,
                session_id,
            )

            self._commit(
                checkpoint, response,
                sent_snapshot, received_snapshot,
                result,
                server_cursors=server_cursors,
                collect_cursors=collect_cursors,
            )

            if manifest.has_more or response.has_more:
                continue

            # Converged: stop only once a round moved nothing at all in
            # either direction; otherwise run another round.
            if not (
                _carries_changes(manifest.payloads)
                or _carries_changes(response.payloads)
            ):
                break

        self._enter_phase(SyncPhase.COMPLETE, session_id)
        self._finalize(result)

    self._log_sync_summary(result)

    return result