diff --git a/apps/obsidian/src/components/canvas/DiscourseRelationTool.ts b/apps/obsidian/src/components/canvas/DiscourseRelationTool.ts index 98a6502d5..64f1269bf 100644 --- a/apps/obsidian/src/components/canvas/DiscourseRelationTool.ts +++ b/apps/obsidian/src/components/canvas/DiscourseRelationTool.ts @@ -2,9 +2,12 @@ import { StateNode, TLEventHandlers, TLStateNodeConstructor } from "tldraw"; import { createShapeId } from "tldraw"; import type { TFile } from "obsidian"; import DiscourseGraphPlugin from "~/index"; -import { getRelationTypeById } from "~/utils/typeUtils"; +import { getNodeTypeById, getRelationTypeById } from "~/utils/typeUtils"; +import { + getCompatibleTargetNodeTypeIds, + getDiscourseNodeTypeId, +} from "~/components/canvas/utils/relationTypeUtils"; import { DiscourseRelationShape } from "./shapes/DiscourseRelationShape"; -import { getNodeTypeById } from "~/utils/typeUtils"; import { showToast } from "./utils/toastUtils"; import { toTldrawColor } from "~/utils/tldrawColors"; @@ -88,38 +91,6 @@ class Pointing extends StateNode { this.cancel(); }; - private getCompatibleNodeTypes = ( - plugin: DiscourseGraphPlugin, - relationTypeId: string, - sourceNodeTypeId: string, - ): string[] => { - const compatibleTypes: string[] = []; - - // Find all discourse relations that match the relation type and source - const relations = plugin.settings.discourseRelations.filter( - (relation) => - relation.relationshipTypeId === relationTypeId && - relation.sourceId === sourceNodeTypeId, - ); - - relations.forEach((relation) => { - compatibleTypes.push(relation.destinationId); - }); - - // Also check reverse relations (where current node is destination) - const reverseRelations = plugin.settings.discourseRelations.filter( - (relation) => - relation.relationshipTypeId === relationTypeId && - relation.destinationId === sourceNodeTypeId, - ); - - reverseRelations.forEach((relation) => { - compatibleTypes.push(relation.sourceId); - }); - - return [...new Set(compatibleTypes)]; // Remove duplicates - }; - override onEnter = () => { this.didTimeout = false; @@ -141,8 +112,7 @@ class Pointing extends StateNode { return; } - const sourceNodeTypeId = (target as { props?: { nodeTypeId?: string } }) - .props?.nodeTypeId; + const sourceNodeTypeId = getDiscourseNodeTypeId(target); if (!sourceNodeTypeId) { this.showWarning("Source node must have a valid node type"); return; @@ -150,11 +120,11 @@ class Pointing extends StateNode { // Check if this source node type can create relations of this type if (sourceNodeTypeId) { - const compatibleTargetTypes = this.getCompatibleNodeTypes( - plugin, + const compatibleTargetTypes = getCompatibleTargetNodeTypeIds({ + discourseRelations: plugin.settings.discourseRelations, relationTypeId, sourceNodeTypeId, - ); + }); if (compatibleTargetTypes.length === 0) { const sourceNodeType = getNodeTypeById(plugin, sourceNodeTypeId); diff --git a/apps/obsidian/src/components/canvas/TldrawViewComponent.tsx b/apps/obsidian/src/components/canvas/TldrawViewComponent.tsx index d196a04b6..c893e7b34 100644 --- a/apps/obsidian/src/components/canvas/TldrawViewComponent.tsx +++ b/apps/obsidian/src/components/canvas/TldrawViewComponent.tsx @@ -46,6 +46,7 @@ import { } from "~/components/canvas/shapes/DiscourseRelationBinding"; import ToastListener from "./ToastListener"; import { RelationsOverlay } from "./overlays/RelationOverlay"; +import { DragHandleOverlay } from "./overlays/DragHandleOverlay"; import { WHITE_LOGO_SVG } from "~/icons"; import { CustomContextMenu } from "./CustomContextMenu"; import { @@ -474,7 +475,10 @@ export const TldrawPreviewComponent = ({ ); }, InFrontOfTheCanvas: () => ( - + <> + + + ), }} /> diff --git a/apps/obsidian/src/components/canvas/overlays/DragHandleOverlay.tsx b/apps/obsidian/src/components/canvas/overlays/DragHandleOverlay.tsx new file mode 100644 index 000000000..ac588e3ec --- /dev/null +++ b/apps/obsidian/src/components/canvas/overlays/DragHandleOverlay.tsx @@ -0,0 +1,449 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { TFile } from "obsidian"; +import { + TLArrowBindingProps, + TLShapeId, + createShapeId, + useEditor, + useValue, +} from "tldraw"; +import DiscourseGraphPlugin from "~/index"; +import { getRelationTypeById } from "~/utils/typeUtils"; +import { DiscourseNodeShape } from "~/components/canvas/shapes/DiscourseNodeShape"; +import { + DiscourseRelationShape, + DiscourseRelationUtil, +} from "~/components/canvas/shapes/DiscourseRelationShape"; +import { + createOrUpdateArrowBinding, + getArrowBindings, +} from "~/components/canvas/utils/relationUtils"; +import { DEFAULT_TLDRAW_COLOR } from "~/utils/tldrawColors"; +import { showToast } from "~/components/canvas/utils/toastUtils"; +import { + getDiscourseNodeAtPoint, + getDiscourseNodeTypeId, + hasValidRelationTypeForNodePair, +} from "~/components/canvas/utils/relationTypeUtils"; +import { RelationTypeDropdown } from "./RelationTypeDropdown"; + +type DragHandleOverlayProps = { + plugin: DiscourseGraphPlugin; + file: TFile; +}; + +type HandlePosition = { + x: number; + y: number; + anchor: { x: number; y: number }; +}; + +const HANDLE_RADIUS = 5; +const HANDLE_HIT_AREA = 12; +const HANDLE_PADDING = 8; // px offset in viewport space, outward from the node edge + +/** Page-space edge midpoints and their outward direction vectors. */ +const getEdgeMidpoints = (bounds: { + minX: number; + minY: number; + maxX: number; + maxY: number; +}): (HandlePosition & { direction: { x: number; y: number } })[] => { + return [ + // Top + { + x: (bounds.minX + bounds.maxX) / 2, + y: bounds.minY, + anchor: { x: 0.5, y: 0 }, + direction: { x: 0, y: -1 }, + }, + // Right + { + x: bounds.maxX, + y: (bounds.minY + bounds.maxY) / 2, + anchor: { x: 1, y: 0.5 }, + direction: { x: 1, y: 0 }, + }, + // Bottom + { + x: (bounds.minX + bounds.maxX) / 2, + y: bounds.maxY, + anchor: { x: 0.5, y: 1 }, + direction: { x: 0, y: 1 }, + }, + // Left + { + x: bounds.minX, + y: (bounds.minY + bounds.maxY) / 2, + anchor: { x: 0, y: 0.5 }, + direction: { x: -1, y: 0 }, + }, + ]; +}; + +export const DragHandleOverlay = ({ plugin, file }: DragHandleOverlayProps) => { + const editor = useEditor(); + const [pendingArrowId, setPendingArrowId] = useState(null); + const [isDragging, setIsDragging] = useState(false); + const sourceNodeRef = useRef(null); + const dragCleanupRef = useRef<(() => void) | null>(null); + + // Clean up drag listeners on unmount + useEffect(() => { + return () => { + dragCleanupRef.current?.(); + }; + }, []); + + // Track the single selected discourse node — mirrors RelationsOverlay pattern + const selectedNode = useValue( + "dragHandleSelectedNode", + () => { + const shape = editor.getOnlySelectedShape(); + if (shape && shape.type === "discourse-node") { + return shape as DiscourseNodeShape; + } + return null; + }, + [editor], + ); + + const handlePositions = useValue< + { left: number; top: number; anchor: { x: number; y: number } }[] | null + >( + "dragHandlePositions", + () => { + if (!selectedNode || pendingArrowId || isDragging) return null; + const bounds = editor.getShapePageBounds(selectedNode.id); + if (!bounds) return null; + const midpoints = getEdgeMidpoints(bounds); + return midpoints.map((mp) => { + const vp = editor.pageToViewport({ x: mp.x, y: mp.y }); + return { + left: vp.x + mp.direction.x * HANDLE_PADDING, + top: vp.y + mp.direction.y * HANDLE_PADDING, + anchor: mp.anchor, + }; + }); + }, + [editor, selectedNode?.id, pendingArrowId, isDragging], + ); + + const cleanupArrow = useCallback( + (arrowId: TLShapeId) => { + if (editor.getShape(arrowId)) { + editor.deleteShapes([arrowId]); + } + }, + [editor], + ); + + const handlePointerDown = useCallback( + (e: React.PointerEvent, anchor: { x: number; y: number }) => { + if (!selectedNode) return; + e.preventDefault(); + e.stopPropagation(); + + setIsDragging(true); + sourceNodeRef.current = selectedNode; + + const arrowId = createShapeId(); + + // Get the source node's page bounds for start position + const sourceBounds = editor.getShapePageBounds(selectedNode.id); + if (!sourceBounds) { + setIsDragging(false); + return; + } + + const startX = sourceBounds.minX + anchor.x * sourceBounds.width; + const startY = sourceBounds.minY + anchor.y * sourceBounds.height; + + // Create the arrow shape at the source node's position + editor.createShape({ + id: arrowId, + type: "discourse-relation", + x: startX, + y: startY, + props: { + color: DEFAULT_TLDRAW_COLOR, + relationTypeId: "", + text: "", + dash: "draw", + size: "m", + fill: "none", + labelColor: "black", + bend: 0, + start: { x: 0, y: 0 }, + end: { x: 0, y: 0 }, + arrowheadStart: "none", + arrowheadEnd: "arrow", + labelPosition: 0.5, + font: "draw", + scale: 1, + kind: "arc", + elbowMidPoint: 0, + }, + }); + + const createdShape = editor.getShape(arrowId); + if (!createdShape) { + setIsDragging(false); + return; + } + + // Bind the start handle to the source node + createOrUpdateArrowBinding(editor, createdShape, selectedNode.id, { + terminal: "start", + normalizedAnchor: anchor, + isPrecise: false, + isExact: false, + snap: "none", + }); + + // Select the arrow and start dragging the end handle + editor.select(arrowId); + + // Use tldraw's built-in handle dragging by setting the tool state + // We need to track the pointer to update the end handle + const containerEl = editor.getContainer(); + const onPointerMove = (moveEvent: PointerEvent) => { + const point = editor.screenToPage({ + x: moveEvent.clientX, + y: moveEvent.clientY, + }); + + // Update the arrow's end position + const currentShape = editor.getShape(arrowId); + if (!currentShape) return; + + const dx = point.x - currentShape.x; + const dy = point.y - currentShape.y; + + // Check for a target shape under the cursor + const target = getDiscourseNodeAtPoint(editor, point, selectedNode.id); + + if (target) { + // Bind end to target + createOrUpdateArrowBinding(editor, currentShape, target.id, { + terminal: "end", + normalizedAnchor: { x: 0.5, y: 0.5 }, + isPrecise: false, + isExact: false, + snap: "none", + }); + editor.setHintingShapes([target.id]); + } else { + // Update free end position + // Remove any existing end binding + const bindings = getArrowBindings(editor, currentShape); + if (bindings.end) { + editor.deleteBindings( + editor + .getBindingsFromShape(currentShape.id, "discourse-relation") + .filter( + (b) => (b.props as TLArrowBindingProps).terminal === "end", + ), + ); + } + editor.updateShapes([ + { + id: arrowId, + type: "discourse-relation", + props: { end: { x: dx, y: dy } }, + }, + ]); + editor.setHintingShapes([]); + } + }; + + const onPointerUp = () => { + containerEl.removeEventListener("pointermove", onPointerMove); + containerEl.removeEventListener("pointerup", onPointerUp); + dragCleanupRef.current = null; + editor.setHintingShapes([]); + setIsDragging(false); + + const finalShape = editor.getShape(arrowId); + if (!finalShape) return; + + const bindings = getArrowBindings(editor, finalShape); + + // Validate: both ends bound to different discourse nodes + if ( + bindings.start && + bindings.end && + bindings.start.toId !== bindings.end.toId + ) { + const endTarget = editor.getShape(bindings.end.toId); + if (endTarget && endTarget.type === "discourse-node") { + // Check if any relation types are valid for this node pair + const startNodeTypeId = getDiscourseNodeTypeId( + editor.getShape(bindings.start.toId), + ); + const endNodeTypeId = getDiscourseNodeTypeId(endTarget); + + const hasValidRelationType = + startNodeTypeId && + endNodeTypeId && + hasValidRelationTypeForNodePair({ + settings: plugin.settings, + sourceNodeTypeId: startNodeTypeId, + targetNodeTypeId: endNodeTypeId, + }); + + if (!hasValidRelationType) { + cleanupArrow(arrowId); + showToast({ + severity: "warning", + title: "Relation", + description: + "No relation types are defined between these node types", + targetCanvasId: file.path, + }); + if (sourceNodeRef.current) { + editor.select(sourceNodeRef.current.id); + } + sourceNodeRef.current = null; + return; + } + + // Success - show dropdown to pick relation type + setPendingArrowId(arrowId); + editor.select(arrowId); + return; + } + } + + // Failure - clean up the arrow and show notice + cleanupArrow(arrowId); + showToast({ + severity: "warning", + title: "Relation", + description: !bindings.end + ? "Drop on a discourse node to create a relation" + : "Target must be a different discourse node", + targetCanvasId: file.path, + }); + // Re-select the source node + if (sourceNodeRef.current) { + editor.select(sourceNodeRef.current.id); + } + sourceNodeRef.current = null; + }; + + containerEl.addEventListener("pointermove", onPointerMove); + containerEl.addEventListener("pointerup", onPointerUp); + + dragCleanupRef.current = () => { + containerEl.removeEventListener("pointermove", onPointerMove); + containerEl.removeEventListener("pointerup", onPointerUp); + dragCleanupRef.current = null; + }; + }, + [selectedNode, editor, cleanupArrow, file.path, plugin.settings], + ); + + const handleDropdownSelect = useCallback( + (relationTypeId: string) => { + if (!pendingArrowId) return; + + const shape = editor.getShape(pendingArrowId); + if (!shape) { + setPendingArrowId(null); + return; + } + + const relationType = getRelationTypeById(plugin, relationTypeId); + if (!relationType) { + cleanupArrow(pendingArrowId); + setPendingArrowId(null); + return; + } + + // Update arrow props with relation type info + editor.updateShapes([ + { + id: pendingArrowId, + type: "discourse-relation", + props: { + relationTypeId, + color: relationType.color, + }, + }, + ]); + + // Get updated shape and bindings for text direction + const updatedShape = + editor.getShape(pendingArrowId); + if (updatedShape) { + const bindings = getArrowBindings(editor, updatedShape); + + // Update text based on direction + const util = editor.getShapeUtil(updatedShape); + if (util instanceof DiscourseRelationUtil) { + util.updateRelationTextForDirection(updatedShape, bindings); + // Persist to relations JSON + void util.reifyRelation(updatedShape, bindings); + } + } + + setPendingArrowId(null); + sourceNodeRef.current = null; + }, + [editor, pendingArrowId, plugin, cleanupArrow], + ); + + const handleDropdownDismiss = useCallback(() => { + if (pendingArrowId) { + cleanupArrow(pendingArrowId); + setPendingArrowId(null); + } + // Re-select source node + if (sourceNodeRef.current) { + editor.select(sourceNodeRef.current.id); + } + sourceNodeRef.current = null; + }, [editor, pendingArrowId, cleanupArrow]); + + const showHandles = !!handlePositions && !pendingArrowId; + + return ( +
+ {/* Drag handle dots */} + {showHandles && + handlePositions.map((pos, i) => ( +
handlePointerDown(e, pos.anchor)} + className="pointer-events-auto absolute z-20 flex cursor-crosshair items-center justify-center" + style={{ + left: `${pos.left}px`, + top: `${pos.top}px`, + width: `${HANDLE_HIT_AREA * 2}px`, + height: `${HANDLE_HIT_AREA * 2}px`, + transform: "translate(-50%, -50%)", + }} + > +
+
+ ))} + + {/* Relation type dropdown */} + {pendingArrowId && ( + + )} +
+ ); +}; diff --git a/apps/obsidian/src/components/canvas/overlays/RelationOverlay.tsx b/apps/obsidian/src/components/canvas/overlays/RelationOverlay.tsx index 048e46b91..a3baff93d 100644 --- a/apps/obsidian/src/components/canvas/overlays/RelationOverlay.tsx +++ b/apps/obsidian/src/components/canvas/overlays/RelationOverlay.tsx @@ -87,7 +87,7 @@ export const RelationsOverlay = ({ plugin, file }: RelationsOverlayProps) => { maxHeight: "calc(100% - 24px)", pointerEvents: "all", overflow: "auto", - zIndex: 10, + zIndex: 25, }} onMouseDown={(e) => e.stopPropagation()} onMouseUp={(e) => e.stopPropagation()} diff --git a/apps/obsidian/src/components/canvas/overlays/RelationTypeDropdown.tsx b/apps/obsidian/src/components/canvas/overlays/RelationTypeDropdown.tsx new file mode 100644 index 000000000..5c494c401 --- /dev/null +++ b/apps/obsidian/src/components/canvas/overlays/RelationTypeDropdown.tsx @@ -0,0 +1,160 @@ +import { useCallback, useEffect, useMemo, useRef } from "react"; +import { TLShapeId, useEditor, useValue } from "tldraw"; +import DiscourseGraphPlugin from "~/index"; +import { DiscourseRelationShape } from "~/components/canvas/shapes/DiscourseRelationShape"; +import { + getArrowBindings, + getArrowInfo, +} from "~/components/canvas/utils/relationUtils"; +import { + getDiscourseNodeTypeId, + getValidRelationTypesForNodePair, +} from "~/components/canvas/utils/relationTypeUtils"; + +type RelationTypeDropdownProps = { + arrowId: TLShapeId; + plugin: DiscourseGraphPlugin; + onSelect: (relationTypeId: string) => void; + onDismiss: () => void; +}; + +export const RelationTypeDropdown = ({ + arrowId, + plugin, + onSelect, + onDismiss, +}: RelationTypeDropdownProps) => { + const editor = useEditor(); + const dropdownRef = useRef(null); + + const arrow = useValue( + "dropdownArrow", + () => editor.getShape(arrowId) ?? null, + [editor, arrowId], + ); + + // Auto-dismiss if arrow is deleted + useEffect(() => { + if (!arrow) { + onDismiss(); + } + }, [arrow, onDismiss]); + + // Get valid relation types based on source/target node types + const validRelationTypes = useMemo(() => { + if (!arrow) return []; + + const bindings = getArrowBindings(editor, arrow); + if (!bindings.start || !bindings.end) return []; + + const startNode = editor.getShape(bindings.start.toId); + const endNode = editor.getShape(bindings.end.toId); + + if (!startNode || !endNode) return []; + + const startNodeTypeId = getDiscourseNodeTypeId(startNode); + const endNodeTypeId = getDiscourseNodeTypeId(endNode); + + if (!startNodeTypeId || !endNodeTypeId) return []; + + return getValidRelationTypesForNodePair({ + settings: plugin.settings, + sourceNodeTypeId: startNodeTypeId, + targetNodeTypeId: endNodeTypeId, + }); + }, [arrow, editor, plugin]); + + // Position dropdown at arrow midpoint + const dropdownPosition = useValue<{ left: number; top: number } | null>( + "dropdownPosition", + () => { + if (!arrow) return null; + + const info = getArrowInfo(editor, arrow); + if (!info) return null; + + // Get the midpoint in page space + const pageTransform = editor.getShapePageTransform(arrow.id); + const midInPage = pageTransform.applyToPoint(info.middle); + + const vp = editor.pageToViewport(midInPage); + return { left: vp.x, top: vp.y }; + }, + [editor, arrow?.id], + ); + + // Handle click outside + useEffect(() => { + const handlePointerDown = (e: PointerEvent) => { + if ( + dropdownRef.current && + !dropdownRef.current.contains(e.target as Node) + ) { + onDismiss(); + } + }; + + // Delay to avoid immediately triggering from the pointer up that opened this + const timer = setTimeout(() => { + window.addEventListener("pointerdown", handlePointerDown, true); + }, 100); + + return () => { + clearTimeout(timer); + window.removeEventListener("pointerdown", handlePointerDown, true); + }; + }, [onDismiss]); + + // Handle Escape key + useEffect(() => { + const handleKeyDown = (e: KeyboardEvent) => { + if (e.key === "Escape") { + onDismiss(); + } + }; + window.addEventListener("keydown", handleKeyDown, true); + return () => window.removeEventListener("keydown", handleKeyDown, true); + }, [onDismiss]); + + const handleSelect = useCallback( + (relationTypeId: string) => { + onSelect(relationTypeId); + }, + [onSelect], + ); + + if (!dropdownPosition || !arrow) return null; + + return ( +
e.stopPropagation()} + onPointerUp={(e) => e.stopPropagation()} + onClick={(e) => e.stopPropagation()} + > +
+
+ Relation Type +
+ {validRelationTypes.map((rt) => ( + + ))} +
+
+ ); +}; diff --git a/apps/obsidian/src/components/canvas/shapes/DiscourseRelationBinding.tsx b/apps/obsidian/src/components/canvas/shapes/DiscourseRelationBinding.tsx index e0161d24f..738a207a5 100644 --- a/apps/obsidian/src/components/canvas/shapes/DiscourseRelationBinding.tsx +++ b/apps/obsidian/src/components/canvas/shapes/DiscourseRelationBinding.tsx @@ -159,8 +159,8 @@ export class BaseRelationBindingUtil extends BindingUtil { BaseRelationBindingUtil.reifiedArrows.add(arrow.id); const util = editor.getShapeUtil(arrow); if (util instanceof DiscourseRelationUtil) { - util.reifyRelationInFrontmatter(arrow, bindings).catch((error) => { - console.error("Failed to reify relation in frontmatter:", error); + util.reifyRelation(arrow, bindings).catch((error) => { + console.error("Failed to reify relation:", error); // Remove from reified set on error so it can be retried BaseRelationBindingUtil.reifiedArrows.delete(arrow.id); }); @@ -433,4 +433,4 @@ function intersectLineSegmentCircle( if (result.length === 0) return null; // no intersection return result; -} \ No newline at end of file +} diff --git a/apps/obsidian/src/components/canvas/shapes/DiscourseRelationShape.tsx b/apps/obsidian/src/components/canvas/shapes/DiscourseRelationShape.tsx index e3b56bc24..d07ed01eb 100644 --- a/apps/obsidian/src/components/canvas/shapes/DiscourseRelationShape.tsx +++ b/apps/obsidian/src/components/canvas/shapes/DiscourseRelationShape.tsx @@ -60,6 +60,13 @@ import { import { RelationBindings } from "./DiscourseRelationBinding"; import { DiscourseNodeShape, DiscourseNodeUtil } from "./DiscourseNodeShape"; import { addRelationToRelationsJson } from "~/components/canvas/utils/relationJsonUtils"; +import { + getDiscourseNodeAtPoint, + getDiscourseNodeTypeId, + getRelationDirection, + isValidRelationConnection, +} from "~/components/canvas/utils/relationTypeUtils"; +import { getNodeTypeById, getRelationTypeById } from "~/utils/typeUtils"; import { showToast } from "~/components/canvas/utils/toastUtils"; export enum ArrowHandles { @@ -241,29 +248,10 @@ export class DiscourseRelationUtil extends ShapeUtil { .getShapePageTransform(shape.id) .applyToPoint(info.handle); - const target = this.editor.getShapeAtPoint(point, { - hitInside: true, - hitFrameInside: true, - margin: 0, - filter: (targetShape) => { - return ( - !targetShape.isLocked && - this.editor.canBindShapes({ - fromShape: shape, - toShape: targetShape, - binding: shape.type, - }) - ); - }, - }); + // Exclude the arrow shape itself to avoid self-binding on initial drag + const target = getDiscourseNodeAtPoint(this.editor, point, shape.id); - if ( - !target || - // TODO - this is a hack/fix - // the shape is targeting itself on initial drag - // find out why - target.id === shape.id - ) { + if (!target) { // TODO re-implement this on pointer up // if ( // currentBinding && @@ -354,28 +342,29 @@ export class DiscourseRelationUtil extends ShapeUtil { ) { const sourceNodeId = otherBinding.toId; const sourceNode = this.editor.getShape(sourceNodeId); - const targetNodeTypeId = (target as { props?: { nodeTypeId?: string } }) - .props?.nodeTypeId; - const sourceNodeTypeId = ( - sourceNode as { props?: { nodeTypeId?: string } } | null - )?.props?.nodeTypeId; + const targetNodeTypeId = getDiscourseNodeTypeId(target); + const sourceNodeTypeId = getDiscourseNodeTypeId(sourceNode); if (sourceNodeTypeId && targetNodeTypeId && shape.props.relationTypeId) { - const isValidConnection = this.isValidNodeConnection( + const isValidConnection = isValidRelationConnection({ + discourseRelations: this.options.plugin.settings.discourseRelations, + relationTypeId: shape.props.relationTypeId, sourceNodeTypeId, targetNodeTypeId, - shape.props.relationTypeId, - ); + }); if (!isValidConnection) { - const sourceNodeType = this.options.plugin.settings.nodeTypes.find( - (nt) => nt.id === sourceNodeTypeId, + const sourceNodeType = getNodeTypeById( + this.options.plugin, + sourceNodeTypeId, ); - const targetNodeType = this.options.plugin.settings.nodeTypes.find( - (nt) => nt.id === targetNodeTypeId, + const targetNodeType = getNodeTypeById( + this.options.plugin, + targetNodeTypeId, ); - const relationType = this.options.plugin.settings.relationTypes.find( - (rt) => rt.id === shape.props.relationTypeId, + const relationType = getRelationTypeById( + this.options.plugin, + shape.props.relationTypeId, ); // Show error toast and delete the entire relation shape @@ -476,7 +465,8 @@ export class DiscourseRelationUtil extends ShapeUtil { // Check if other shapes are also being translated const selectedShapeIds = this.editor.getSelectedShapeIds(); - const onlyRelationSelected = selectedShapeIds.length === 1 && selectedShapeIds[0] === shape.id; + const onlyRelationSelected = + selectedShapeIds.length === 1 && selectedShapeIds[0] === shape.id; // If both ends are bound AND only the relation is selected, convert translation to bend changes // If other shapes are also selected, do a simple translation instead @@ -1085,38 +1075,25 @@ export class DiscourseRelationUtil extends ShapeUtil { if (!startNode || !endNode) return; - const startNodeTypeId = (startNode as { props?: { nodeTypeId?: string } }) - ?.props?.nodeTypeId; - const endNodeTypeId = (endNode as { props?: { nodeTypeId?: string } }) - ?.props?.nodeTypeId; + const startNodeTypeId = getDiscourseNodeTypeId(startNode); + const endNodeTypeId = getDiscourseNodeTypeId(endNode); if (!startNodeTypeId || !endNodeTypeId) return; - const relationType = plugin.settings.relationTypes.find( - (rt) => rt.id === relationTypeId, - ); + const relationType = getRelationTypeById(plugin, relationTypeId); if (!relationType) return; - // Check if this is a direct connection (start -> end) - const isDirectConnection = plugin.settings.discourseRelations.some( - (relation) => - relation.relationshipTypeId === relationTypeId && - relation.sourceId === startNodeTypeId && - relation.destinationId === endNodeTypeId, - ); - - // Check if this is a reverse connection (end -> start, so we need complement) - const isReverseConnection = plugin.settings.discourseRelations.some( - (relation) => - relation.relationshipTypeId === relationTypeId && - relation.sourceId === endNodeTypeId && - relation.destinationId === startNodeTypeId, - ); + const { direct, reverse } = getRelationDirection({ + discourseRelations: plugin.settings.discourseRelations, + relationTypeId, + sourceNodeTypeId: startNodeTypeId, + targetNodeTypeId: endNodeTypeId, + }); let newText = relationType.label; // Default to main label - if (isReverseConnection && !isDirectConnection) { + if (reverse && !direct) { // This is purely a reverse connection, use complement newText = relationType.complement; } @@ -1142,35 +1119,19 @@ export class DiscourseRelationUtil extends ShapeUtil { targetNodeTypeId: string, relationTypeId: string, ): boolean { - const plugin = this.options.plugin; - - // Check direct connection (source -> target) - const directConnection = plugin.settings.discourseRelations.some( - (relation) => - relation.relationshipTypeId === relationTypeId && - relation.sourceId === sourceNodeTypeId && - relation.destinationId === targetNodeTypeId, - ); - - if (directConnection) return true; - - // Check reverse connection (target -> source) - // This handles bidirectional relations where the complement is used - const reverseConnection = plugin.settings.discourseRelations.some( - (relation) => - relation.relationshipTypeId === relationTypeId && - relation.sourceId === targetNodeTypeId && - relation.destinationId === sourceNodeTypeId, - ); - - return reverseConnection; + return isValidRelationConnection({ + discourseRelations: this.options.plugin.settings.discourseRelations, + relationTypeId, + sourceNodeTypeId, + targetNodeTypeId, + }); } /** - * Reifies the relation in the frontmatter of both connected files. + * Reifies the relation in the relations JSON of both connected files. * This creates the bidirectional links that make the relation persistent. */ - async reifyRelationInFrontmatter( + async reifyRelation( shape: DiscourseRelationShape, bindings: RelationBindings, ): Promise { @@ -1231,8 +1192,9 @@ export class DiscourseRelationUtil extends ShapeUtil { }); } - const relationType = this.options.plugin.settings.relationTypes.find( - (rt) => rt.id === shape.props.relationTypeId, + const relationType = getRelationTypeById( + this.options.plugin, + shape.props.relationTypeId, ); if (relationType && !alreadyExisted) { @@ -1243,7 +1205,7 @@ export class DiscourseRelationUtil extends ShapeUtil { }); } } catch (error) { - console.error("Failed to reify relation in frontmatter:", error); + console.error("Failed to reify relation:", error); showToast({ severity: "error", title: "Failed to Save Relation", diff --git a/apps/obsidian/src/components/canvas/utils/relationTypeUtils.ts b/apps/obsidian/src/components/canvas/utils/relationTypeUtils.ts new file mode 100644 index 000000000..bbd39b4bb --- /dev/null +++ b/apps/obsidian/src/components/canvas/utils/relationTypeUtils.ts @@ -0,0 +1,181 @@ +import type { Editor, TLShape, TLShapeId, VecLike } from "tldraw"; +import type { DiscourseRelation, DiscourseRelationType } from "~/types"; +import { COLOR_PALETTE } from "~/utils/tldrawColors"; + +/** + * Finds the discourse node shape at a given page point, excluding an optional + * shape by ID (e.g. the source node or the relation arrow itself). + */ +export const getDiscourseNodeAtPoint = ( + editor: Editor, + point: VecLike, + excludeShapeId?: TLShapeId, +): TLShape | undefined => { + return editor.getShapeAtPoint(point, { + hitInside: true, + hitFrameInside: true, + margin: 0, + filter: (targetShape) => + targetShape.type === "discourse-node" && + !targetShape.isLocked && + targetShape.id !== excludeShapeId, + }); +}; + +/** + * Extracts the nodeTypeId from any tldraw shape that may have it. + * Avoids repeating the same unsafe cast across multiple files. + */ +export const getDiscourseNodeTypeId = (shape: unknown): string | undefined => { + const typed = shape as { props?: { nodeTypeId?: string } } | null | undefined; + return typed?.props?.nodeTypeId; +}; + +type RelationTypeSettings = { + discourseRelations: DiscourseRelation[]; + relationTypes: DiscourseRelationType[]; +}; + +/** + * Checks the direction of a discourse relation between two node types. + * Returns whether the relation exists in the direct (source→target) + * and/or reverse (target→source) direction. + */ +export const getRelationDirection = ({ + discourseRelations, + relationTypeId, + sourceNodeTypeId, + targetNodeTypeId, +}: { + discourseRelations: DiscourseRelation[]; + relationTypeId: string; + sourceNodeTypeId: string; + targetNodeTypeId: string; +}): { direct: boolean; reverse: boolean } => { + let direct = false; + let reverse = false; + + for (const relation of discourseRelations) { + if (relation.relationshipTypeId !== relationTypeId) continue; + if ( + relation.sourceId === sourceNodeTypeId && + relation.destinationId === targetNodeTypeId + ) { + direct = true; + } + if ( + relation.sourceId === targetNodeTypeId && + relation.destinationId === sourceNodeTypeId + ) { + reverse = true; + } + if (direct && reverse) break; + } + + return { direct, reverse }; +}; + +/** + * Returns the list of valid relation types for a given pair of node types, + * checking both directions of the discourse relations. + */ +export const getValidRelationTypesForNodePair = ({ + settings, + sourceNodeTypeId, + targetNodeTypeId, +}: { + settings: RelationTypeSettings; + sourceNodeTypeId: string; + targetNodeTypeId: string; +}): { id: string; label: string; color: string }[] => { + const validTypes: { id: string; label: string; color: string }[] = []; + + for (const relationType of settings.relationTypes) { + const { direct, reverse } = getRelationDirection({ + discourseRelations: settings.discourseRelations, + relationTypeId: relationType.id, + sourceNodeTypeId, + targetNodeTypeId, + }); + + if (direct || reverse) { + validTypes.push({ + id: relationType.id, + label: relationType.label, + color: COLOR_PALETTE[relationType.color] ?? COLOR_PALETTE["black"]!, + }); + } + } + + return validTypes; +}; + +/** + * Checks whether a specific relation type can connect the given source and + * target node types (in either direction). + */ +export const isValidRelationConnection = ({ + discourseRelations, + relationTypeId, + sourceNodeTypeId, + targetNodeTypeId, +}: { + discourseRelations: DiscourseRelation[]; + relationTypeId: string; + sourceNodeTypeId: string; + targetNodeTypeId: string; +}): boolean => { + const { direct, reverse } = getRelationDirection({ + discourseRelations, + relationTypeId, + sourceNodeTypeId, + targetNodeTypeId, + }); + return direct || reverse; +}; + +/** + * Returns the valid target node type IDs for a given relation type and source + * node type, checking both forward and reverse directions. + */ +export const getCompatibleTargetNodeTypeIds = ({ + discourseRelations, + relationTypeId, + sourceNodeTypeId, +}: { + discourseRelations: DiscourseRelation[]; + relationTypeId: string; + sourceNodeTypeId: string; +}): string[] => { + const targets = new Set(); + for (const relation of discourseRelations) { + if (relation.relationshipTypeId !== relationTypeId) continue; + if (relation.sourceId === sourceNodeTypeId) + targets.add(relation.destinationId); + if (relation.destinationId === sourceNodeTypeId) + targets.add(relation.sourceId); + } + return [...targets]; +}; + +/** + * Checks whether any valid relation type exists between two node types. + */ +export const hasValidRelationTypeForNodePair = ({ + settings, + sourceNodeTypeId, + targetNodeTypeId, +}: { + settings: RelationTypeSettings; + sourceNodeTypeId: string; + targetNodeTypeId: string; +}): boolean => { + return settings.discourseRelations.some( + (r) => + settings.relationTypes.some((rt) => rt.id === r.relationshipTypeId) && + ((r.sourceId === sourceNodeTypeId && + r.destinationId === targetNodeTypeId) || + (r.sourceId === targetNodeTypeId && + r.destinationId === sourceNodeTypeId)), + ); +};