import {
  Command,
  EditorState,
  NodeSelection,
  TextSelection,
  Transaction,
} from 'prosemirror-state';
import { Node } from 'prosemirror-model';
import { canSplit } from 'prosemirror-transform';
import { findParentNodeOfType } from '../util';
import { LIST_NUMBERING } from '../schema/nodes/nodeNames';
import { NodeWithPos } from '../ClaimEditorProvider/editorView';

export const claimPartEnter = (
  startNewParagraph = true,
  listType?: string
): Command => {
  /**
   * Enter-key command for `claimPart` nodes.
   *
   * Behaviour, in order:
   * - Cursor inside a LIST_NUMBERING node: the key press is consumed
   *   (returns `true`) and the document is not changed.
   * - Selection not inside a `claimPart`: returns `false` so other handlers
   *   can run.
   * - Cursor at offset 0 of its textblock: an empty `claimPart` (same
   *   `numberingTemplate` and `indentLevel`, `numberingType: 'none'`) is
   *   inserted before the current claimPart. The current one is re-marked
   *   with `startNewParagraph: true`, and the cursor is placed at the start
   *   of the (now shifted) current part.
   * - Otherwise: the current claimPart is split at the cursor via
   *   `splitClaimPart`. The command returns `true` even when the split is not
   *   possible, so Enter never falls through to another handler in that case.
   *
   * @param startNewParagraph - value written to the `startNewParagraph`
   *   attribute of the newly created claim part.
   * @param listType - when truthy, overrides the `numberingType` of the part
   *   produced by a split (it is not used by the insert-before case).
   */
  return (state, dispatch) => {
    const { $from } = state.selection;

    // Enter inside a list-numbering node is swallowed: no edit, no fallthrough.
    if ($from.parent.type.name === LIST_NUMBERING) {
      return true;
    }

    const claimPartOnCursor = findParentNodeOfType(
      state.schema.nodes.claimPart
    )(state.selection);

    if (!claimPartOnCursor) {
      return false;
    }

    // Cursor at the very start of its textblock (parentOffset is measured from
    // the start of $from.parent's content). Note this is the start of the
    // textblock, which is not necessarily the start of the claimPart.
    if ($from.parentOffset === 0) {
      const newClaimPart = state.schema.nodes.claimPart.createAndFill(
        {
          numberingTemplate: claimPartOnCursor.node.attrs.numberingTemplate,
          numberingType: 'none',
          indentLevel: claimPartOnCursor.node.attrs.indentLevel,
          startNewParagraph,
        },
        [state.schema.nodes.paragraph.create()]
      );

      if (!newClaimPart) {
        return false;
      }

      if (dispatch) {
        const tr = state.tr;
        tr.insert(claimPartOnCursor.pos, newClaimPart);

        // The original claimPart has been pushed forward by the inserted node.
        const shiftedPos = claimPartOnCursor.pos + newClaimPart.nodeSize;
        tr.setNodeMarkup(shiftedPos, null, {
          ...claimPartOnCursor.node.attrs,
          startNewParagraph: true,
        });

        // Keep the cursor with the original content: first valid text
        // position inside the shifted claimPart, searching forward.
        const newSelectionPos = tr.doc.resolve(shiftedPos + 1);
        tr.setSelection(TextSelection.near(newSelectionPos, 1));
        dispatch(tr.scrollIntoView());
      }

      return true;
    }

    // Split the claim part if the cursor is in the middle or at the end.
    const maybeTr = splitClaimPart(
      state,
      {
        numberingTemplate: claimPartOnCursor.node.attrs.numberingTemplate,
        numberingType:
          listType || claimPartOnCursor.node.attrs.numberingType,
        indentLevel: Math.max(0, claimPartOnCursor.node.attrs.indentLevel),
        startNewParagraph,
      },
      claimPartOnCursor
    );

    if (maybeTr && dispatch) {
      // Same as the insert-before branch: keep the cursor visible.
      dispatch(maybeTr.scrollIntoView());
    }

    // Consumed even when no split was possible (maybeTr === null).
    return true;
  };

  /**
   * Splits the claimPart containing the cursor at `$from`, deleting any
   * selected range first. The structure appears adapted from
   * prosemirror-schema-list's `splitListItem`.
   *
   * Returns `null` (no change) when:
   * - the selection is a block NodeSelection, is shallower than depth 2, or
   *   does not start and end in the same parent;
   * - the cursor's parent is not a direct child of a `claimPart`;
   * - the cursor is in an empty textblock that is the last child of the part;
   * - `canSplit` rejects the split.
   *
   * When the cursor is at the end of its textblock and `numberingType` is not
   * `'custom'`, the split creates a new claimPart with `itemAttrs` holding a
   * node of the part's default content type. Otherwise the split keeps the
   * existing node types, and `itemAttrs` are then applied with
   * `setNodeMarkup` to the claimPart containing the mapped selection
   * (expected to be the newly split-off part). For `'custom'` numbering, the
   * original part's LIST_NUMBERING child is copied to the start of that part.
   *
   * @param state - current editor state.
   * @param itemAttrs - attributes for the new claimPart.
   * @param claimPartOnCursor - the claimPart (with position) holding the cursor.
   * @returns the transaction to dispatch, or `null` if no split applies.
   */
  function splitClaimPart(
    state: EditorState,
    itemAttrs: Record<string, unknown>,
    claimPartOnCursor: NodeWithPos
  ): Transaction | null {
    const itemType = state.schema.nodes.claimPart;
    // `node` is only defined when the selection is a NodeSelection.
    const { $from, $to, node } = state.selection as NodeSelection;
    if ((node && node.isBlock) || $from.depth < 2 || !$from.sameParent($to)) {
      return null;
    }
    const grandParent = $from.node(-1);
    if (grandParent.type !== itemType) {
      return null;
    }
    // Empty last textblock of the part: nothing to split off.
    if (
      $from.parent.content.size === 0 &&
      grandParent.childCount === $from.indexAfter(-1)
    ) {
      return null;
    }
    const nextType =
      $to.pos === $from.end() && itemAttrs.numberingType !== 'custom'
        ? grandParent.contentMatchAt(0).defaultType
        : null;
    const tr = state.tr.delete($from.pos, $to.pos);
    // types[0] applies to the claimPart level, types[1] to the textblock
    // level. itemAttrs are passed to the textblock too; attributes its type
    // does not declare are presumably ignored — NOTE(review): confirm.
    const types = nextType
      ? [
          { type: itemType, attrs: itemAttrs },
          { type: nextType, attrs: itemAttrs },
        ]
      : undefined;
    if (!canSplit(tr.doc, $from.pos, 2, types)) {
      return null;
    }
    tr.split($from.pos, 2, types);
    if (types === undefined) {
      const item = findParentNodeOfType(itemType)(tr.selection);
      if (item) {
        tr.setNodeMarkup(item.pos, itemType, itemAttrs);
        if (itemAttrs.numberingType === 'custom') {
          const listNumberingType = state.schema.nodes[LIST_NUMBERING];
          const claimPartNode = claimPartOnCursor.node;
          // Last LIST_NUMBERING child of the original part, if any.
          let listNumberingNode: Node | null = null;
          for (let i = 0; i < claimPartNode.childCount; i++) {
            const child = claimPartNode.child(i);
            if (child.type === listNumberingType) {
              listNumberingNode = child;
            }
          }
          // If none was found this creates a node with default attrs and no
          // content — NOTE(review): throws if LIST_NUMBERING has required
          // attributes without defaults.
          const numberingNode = listNumberingType.create(
            listNumberingNode?.attrs,
            listNumberingNode?.content
          );
          tr.insert(item.pos + 1, numberingNode);
        }
      }
    }
    return tr;
  }
};
