import { Command, Transaction } from 'prosemirror-state';
import { Fragment, Node, Slice } from 'prosemirror-model';
import { ReplaceAroundStep } from 'prosemirror-transform';
import { findParentNodeOfType } from '../util';

/**
 * ProseMirror command that indents the claimPart containing the selection by
 * nesting it inside the claimPart immediately before it.
 *
 * The command applies only when:
 *  - the selection is inside a claimPart,
 *  - that claimPart has a previous sibling which is also a claimPart, and
 *  - both have the same `indentLevel` (a missing level counts as 0).
 *
 * When dispatched, it raises the current claimPart's `indentLevel` by one,
 * re-levels the claimParts nested inside it (via `walkThroughClaimPart`), and
 * then moves it to the end of the previous sibling's content.
 *
 * @returns `true` if the command is applicable (and was dispatched when a
 *   `dispatch` function was supplied), otherwise `false`.
 */
export const claimPartIndent: Command = (state, dispatch) => {
  const claimPartType = state.schema.nodes.claimPart;
  const claimPartOnCursor = findParentNodeOfType(claimPartType)(
    state.selection
  );

  if (!claimPartOnCursor) {
    return false;
  }

  const currentClaimPart = claimPartOnCursor.node;
  // Position directly before the current claimPart node.
  const claimPartPos = claimPartOnCursor.pos;
  const currentIndentLevel = currentClaimPart.attrs.indentLevel ?? 0;

  // `nodeBefore` at the position just before the current claimPart is its
  // previous sibling, or null when it is the first child. Unlike resolving
  // `pos - 1`, this cannot throw when the claimPart is the document's first
  // node (pos === 0).
  const claimPartAbove = state.doc.resolve(claimPartPos).nodeBefore;

  // Only nest into another claimPart: the indent-level comparison below is
  // only meaningful for claimParts, and nesting into a different node type
  // would presumably produce content the schema rejects (making `tr.step`
  // throw instead of the command returning false).
  if (!claimPartAbove || claimPartAbove.type !== claimPartType) {
    return false;
  }

  const aboveIndentLevel = claimPartAbove.attrs.indentLevel ?? 0;

  if (currentIndentLevel !== aboveIndentLevel) {
    return false;
  }

  if (dispatch) {
    const indentLevel = currentIndentLevel + 1;
    const claimPartEnd = claimPartPos + currentClaimPart.nodeSize;

    // Access `state.tr` once: each access creates a fresh transaction.
    let tr = state.tr.setNodeMarkup(claimPartPos, null, {
      ...currentClaimPart.attrs,
      indentLevel,
    });

    tr = walkThroughClaimPart(currentClaimPart, claimPartPos, indentLevel, tr);

    // The slice is an empty claimPart that is open at its start, i.e. it
    // stands for a lone closing claimPart token. Replacing the previous
    // sibling's closing token (at claimPartPos - 1) plus the current claimPart
    // (the preserved gap) with "<gap></claimPart>" moves the current claimPart
    // to the end of the previous sibling's content. setNodeMarkup steps keep
    // node sizes unchanged, so these original-document positions remain valid.
    const slice = new Slice(Fragment.from(claimPartType.create()), 1, 0);

    tr = tr.step(
      new ReplaceAroundStep(
        claimPartPos - 1,
        claimPartEnd,
        claimPartPos,
        claimPartEnd,
        slice,
        0,
        true
      )
    );

    dispatch(tr.scrollIntoView());
  }

  return true;
};

/**
 * Re-levels the claimParts nested inside `node` after `node` itself has been
 * given `parentIndentLevel`.
 *
 * Each direct claimPart child receives `parentIndentLevel + 1` (all its other
 * attrs are kept), and its own claimPart children are processed recursively.
 * Children of any other type are skipped and not looked into.
 *
 * @param node - The claimPart whose nested claimParts should be re-levelled.
 * @param parentNodePos - Position directly before `node` in the document the
 *   transaction started from (before any of `tr`'s steps). Every computed
 *   position is mapped through `tr.mapping` exactly once.
 * @param parentIndentLevel - The indent level `node` now has.
 * @param tr - Transaction that receives the `setNodeMarkup` steps.
 * @returns The same transaction, for chaining.
 */
function walkThroughClaimPart(
  node: Node,
  parentNodePos: number,
  parentIndentLevel: number,
  tr: Transaction
): Transaction {
  const childIndentLevel = parentIndentLevel + 1;
  // The node's content starts one token after its opening position.
  const contentStart = parentNodePos + 1;

  node.forEach((childNode, offset) => {
    if (childNode.type.name !== 'claimPart') {
      return;
    }

    // Stay in original-document coordinates and map only at the point of
    // use; recursing with an already-mapped position would map it twice and
    // double-count any shift recorded in `tr.mapping`.
    const childPos = contentStart + offset;

    tr.setNodeMarkup(tr.mapping.map(childPos), null, {
      ...childNode.attrs,
      indentLevel: childIndentLevel,
    });

    if (childNode.childCount > 0) {
      walkThroughClaimPart(childNode, childPos, childIndentLevel, tr);
    }
  });

  return tr;
}
