// Copyright (C) 2024-2025 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT

import { range, cloneDeep } from '@modules/lodash';

import {
    Job, Task, ActionParameterType,
    BaseCollectionAction, ShapeType, ObjectType,
    ObjectState, Source, CVATCore,
    ServerError,
} from '@root/cvat-core-wrapper';

// Type of the `collection` property of the first argument accepted by `BaseCollectionAction.run`
type Collection = Parameters<BaseCollectionAction['run']>[0]['collection'];
// Type of an entry of `Collection['tracks']`
type Track = Collection['tracks'][0];
// Type of an entry of `Collection['shapes']`
type Shape = Collection['shapes'][0];

/**
 * Collection action that propagates mask and polygon annotations from the current frame
 * forward to a user-selected target frame.
 *
 * The propagation itself is performed by a server-side lambda function addressed by model id:
 * the first request sends the shapes of the current frame, every following request sends
 * the states returned by the first response together with the next frame number.
 */
export class SAM2TrackerAction extends BaseCollectionAction {
    readonly #modelID: string;
    readonly #weights: 'large' | 'base-plus';
    readonly #core: CVATCore;
    #instance: Job | Task | null;
    #targetFrame: number;
    #convertPolygonShapesToTracks: boolean;

    /**
     * @param modelID - identifier of the lambda function, used to build the request URL
     * @param weights - name of the model weights, only shown in the action name
     * @param core - core instance used to send requests to the server
     */
    public constructor(modelID: string, weights: 'large' | 'base-plus', core: CVATCore) {
        super();
        this.#convertPolygonShapesToTracks = false;
        this.#weights = weights;
        this.#modelID = modelID;
        this.#instance = null;
        this.#targetFrame = 0;
        this.#core = core;
    }

    /**
     * Remembers the instance to work with and reads the action parameters.
     *
     * @param instance - job or task the action is applied to
     * @param parameters - parameter values keyed by the names declared in `parameters`
     */
    public async init(instance: Job | Task, parameters: Record<string, string>): Promise<void> {
        this.#instance = instance;
        this.#targetFrame = Number(parameters['Target frame']);
        this.#convertPolygonShapesToTracks = parameters['Convert polygon shapes to tracks'] === 'true';
    }

    public async destroy(): Promise<void> {
        // nothing to destroy
    }

    /**
     * Tracks the objects of the current frame on every frame up to the target frame (inclusive).
     *
     * Existing track objects from `collection.tracks` are modified in place; the returned result
     * deletes them and creates them again together with the newly produced shapes and tracks.
     *
     * @returns created and deleted objects, or empty lists when there is nothing to do
     *          or the action was cancelled
     * @throws Error when the target frame is before the current frame
     * @throws ServerError when a request fails with a non-retriable code,
     *         or keeps failing after all retry attempts
     */
    public async run(
        {
            collection,
            frameData: { number },
            onProgress,
            cancelled,
        }: Parameters<BaseCollectionAction['run']>[0],
    ): ReturnType<BaseCollectionAction['run']> {
        const noChanges = {
            created: { shapes: [], tags: [], tracks: [] },
            deleted: { shapes: [], tags: [], tracks: [] },
        };

        const instance = this.#instance;
        const targetFrame = this.#targetFrame;
        if (instance === null || number === targetFrame) {
            return noChanges;
        }

        if (number > targetFrame) {
            throw new Error('AI propagation backward is not supported');
        }

        const frameNumbers = instance instanceof Job ?
            await instance.frames.frameNumbers() : range(0, instance.size);
        // only forward propagation is possible here, so targets are in (number, targetFrame]
        const targetFrameNumbers = frameNumbers.filter(
            (frameNumber: number) => frameNumber > number && frameNumber <= targetFrame,
        );

        if (targetFrameNumbers.length === 0) {
            return noChanges;
        }

        const objectStates = await instance.annotations.get(number, false, []);
        const tracks = [...collection.tracks]; // shallow copy value as it may have added new values
        const { shapes } = collection;

        // the four arrays below are index-aligned: the i-th shape sent to the server
        // corresponds to the i-th target object and the i-th target object state
        const initialShapes: { type: ShapeType.MASK | ShapeType.POLYGON, points: number[] }[] = [];
        const initialStates: unknown[] = [];
        const targetObjects: (Shape | Track)[] = [];
        const targetObjectStates: ObjectState[] = [];
        // polygon shapes that were replaced with tracks and therefore must be deleted
        const convertedShapes: Shape[] = [];

        for (const object of [...shapes, ...collection.tracks] as (Shape | Track)[]) {
            if (!Number.isInteger(object.clientID)) {
                continue;
            }

            const objectState = objectStates.find((_objectState) => _objectState.clientID === object.clientID);
            if (!objectState) {
                continue;
            }

            initialShapes.push({
                type: objectState.shapeType as ShapeType.MASK | ShapeType.POLYGON,
                points: objectState.points as number[],
            });
            initialStates.push(null);

            if (
                this.#convertPolygonShapesToTracks &&
                objectState.shapeType === ShapeType.POLYGON &&
                objectState.objectType === ObjectType.SHAPE
            ) {
                const castedObject = object as Required<Shape>;
                const convertedTrack = {
                    source: Source.AUTO,
                    attributes: [],
                    elements: [],
                    frame: object.frame,
                    group: object.group,
                    label_id: object.label_id,
                    shapes: [{
                        frame: object.frame,
                        attributes: [],
                        occluded: castedObject.occluded,
                        outside: false,
                        points: [...castedObject.points],
                        rotation: castedObject.rotation,
                        z_order: castedObject.z_order,
                        type: castedObject.type,
                    }],
                };

                tracks.push(convertedTrack);
                convertedShapes.push(object as Shape);
                targetObjects.push(convertedTrack);
                // the state still describes a shape, so report it as a track to the code below
                targetObjectStates.push(new Proxy(objectState, {
                    get(target, property, receiver) {
                        if (property === 'objectType') {
                            return ObjectType.TRACK;
                        }

                        return Reflect.get(target, property, receiver);
                    },
                }));
            } else {
                targetObjects.push(object);
                targetObjectStates.push(objectState);
            }
        }

        const functionURL = `/api/lambda/functions/${this.#modelID}`;
        // NOTE(review): for a Task instance `taskId` is undefined and `job` receives the task id;
        // confirm this is what the server expects
        const requestTarget = {
            job: instance.id,
            task: (instance as Job).taskId,
        };

        const initialResponse = await this.#core.server.request(functionURL, {
            method: 'post',
            data: {
                frame: number,
                shapes: initialShapes,
                states: initialStates,
                ...requestTarget,
            },
        });

        const totalFrames = targetFrameNumbers.length + 1;
        onProgress('Action is running', Math.ceil((1 / totalFrames) * 100));
        if (cancelled()) {
            return noChanges;
        }

        const maxAttempts = 10;
        const retryDelayMs = 5000;
        const retriableCodes = [0, 502, 503, 504];
        // following requests carry only states, the shapes are replaced with placeholders
        const shapePlaceholders = initialResponse.data.shapes.map(() => null);

        const trackedShapes: Shape[] = [];
        for (let i = 0; i < targetFrameNumbers.length; i++) {
            const frame = targetFrameNumbers[i];
            const meta = await instance.frames.get(frame);
            if (meta.deleted) {
                continue;
            }

            let response = null;
            let lastError: ServerError | null = null;
            let attemptsLeft = maxAttempts;
            while (attemptsLeft > 0) {
                try {
                    response = await this.#core.server.request(functionURL, {
                        method: 'post',
                        data: {
                            frame,
                            shapes: shapePlaceholders,
                            states: initialResponse.data.states,
                            ...requestTarget,
                        },
                    });

                    break;
                } catch (error) {
                    if (!(error instanceof ServerError) || !retriableCodes.includes(error.code)) {
                        throw error;
                    }

                    lastError = error;
                    attemptsLeft--;
                    if (attemptsLeft > 0) {
                        await new Promise((resolve) => setTimeout(resolve, retryDelayMs));
                    }
                }
            }

            if (response === null) {
                // all attempts failed with a retriable error, report the last one
                // instead of failing later on reading properties of null
                throw lastError ?? new Error('Could not get a response from the server');
            }

            onProgress('Action is running', Math.ceil(((i + 2) / totalFrames) * 100));
            if (cancelled()) {
                return noChanges;
            }

            response.data.shapes.forEach((shape: { type: ShapeType, points: number[] }, idx: number) => {
                const targetObjectState = targetObjectStates[idx];
                const targetObject = targetObjects[idx];
                const isVisibleShape = (
                    (shape.type === ShapeType.MASK && shape.points.length > 5) || // not empty RLE
                    (shape.type === ShapeType.POLYGON && shape.points.length >= 6) // or correct polygon
                );

                if (targetObjectState.objectType === ObjectType.TRACK) {
                    const track = targetObject as Track;
                    // the shape with the biggest frame number among the shapes before the current frame,
                    // the first shape of the track is used only if there are no such shapes
                    const nearestPreviousShape = track.shapes.reduce((nearest, candidate) => {
                        if (candidate.frame >= frame) {
                            return nearest;
                        }

                        if (nearest.frame >= frame || candidate.frame > nearest.frame) {
                            return candidate;
                        }

                        return nearest;
                    }, track.shapes[0]);
                    const previousPoints = [...nearestPreviousShape.points as number[]];

                    const existingShape = track.shapes
                        .find((_shape) => _shape.frame === frame) ?? null;
                    if (existingShape) {
                        existingShape.outside = !isVisibleShape;
                        existingShape.points = isVisibleShape ? shape.points : previousPoints;
                    } else {
                        if (!isVisibleShape && nearestPreviousShape.outside) {
                            // we do not need to add more than one "outside" shape in a row
                            return;
                        }

                        track.shapes.push({
                            frame,
                            attributes: [],
                            occluded: targetObjectState.occluded,
                            outside: !isVisibleShape,
                            points: isVisibleShape ? shape.points : previousPoints,
                            rotation: 0,
                            z_order: targetObjectState.zOrder,
                            type: targetObjectState.shapeType,
                        });
                    }
                } else if (isVisibleShape) {
                    const castedObject = (targetObject as Shape);
                    trackedShapes.push({
                        elements: [],
                        group: targetObject.group,
                        frame,
                        source: Source.AUTO,
                        attributes: cloneDeep(targetObject.attributes),
                        occluded: castedObject.occluded,
                        outside: false,
                        points: shape.points,
                        rotation: 0,
                        z_order: castedObject.z_order,
                        label_id: castedObject.label_id,
                        type: castedObject.type,
                    });
                }
            });
        }

        return {
            created: { shapes: trackedShapes, tags: [], tracks },
            deleted: {
                // only the shapes that were really converted to tracks are removed,
                // the list is empty when the conversion is disabled
                shapes: convertedShapes,
                // all input tracks are removed as they are created again (see "created" above)
                tracks: collection.tracks,
                tags: [],
            },
        };
    }

    /**
     * Selects objects the action is able to work with on the current frame:
     * mask and polygon shapes located on the frame, and polygon tracks
     * whose latest shape at or before the frame is not "outside".
     */
    public applyFilter(
        input: Parameters<BaseCollectionAction['applyFilter']>[0],
    ): ReturnType<BaseCollectionAction['applyFilter']> {
        const { collection, frameData } = input;

        return {
            shapes: collection.shapes.filter((shape) => shape.frame === frameData.number &&
                [ShapeType.MASK, ShapeType.POLYGON].includes(shape.type)),
            tags: [],
            tracks: collection.tracks.filter((track) => {
                if (track.shapes.length === 0 || track.shapes[0].type !== ShapeType.POLYGON) {
                    return false;
                }

                // must be any shapes before current frame
                const shapesBefore = track.shapes
                    .filter((shape) => shape.frame <= frameData.number)
                    .sort((a, b) => a.frame - b.frame);
                if (shapesBefore.length === 0) {
                    return false;
                }

                // last shape must not be outside
                return !shapesBefore[shapesBefore.length - 1].outside;
            }),
        };
    }

    /** Checks that the object is a shape or a track of mask or polygon type. */
    public isApplicableForObject(objectState: ObjectState): boolean {
        return [ObjectType.SHAPE, ObjectType.TRACK].includes(objectState.objectType) &&
            [ShapeType.MASK, ShapeType.POLYGON].includes(objectState.shapeType);
    }

    /** Action name, includes the name of the weights passed to the constructor. */
    public get name(): BaseCollectionAction['name'] {
        return `Segment Anything 2: Tracker (${this.#weights})`;
    }

    /**
     * Parameters of the action: a checkbox enabling conversion of polygon shapes to tracks
     * and the number of the target frame, which defaults to the last frame of the instance.
     */
    public get parameters(): BaseCollectionAction['parameters'] {
        return {
            'Convert polygon shapes to tracks': {
                type: ActionParameterType.CHECKBOX,
                values: ['true', 'false'],
                defaultValue: String(this.#convertPolygonShapesToTracks),
            },
            'Target frame': {
                type: ActionParameterType.NUMBER,
                values: ({ instance }) => {
                    if (instance instanceof Job) {
                        return [instance.startFrame, instance.stopFrame, 1].map((val) => val.toString());
                    }
                    return [0, instance.size - 1, 1].map((val) => val.toString());
                },
                defaultValue: ({ instance }) => {
                    if (instance instanceof Job) {
                        return instance.stopFrame.toString();
                    }
                    return (instance.size - 1).toString();
                },
            },
        };
    }
}
