// Copyright (C) 2024 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT

import { CVATCore, MLModel } from '@root/cvat-core-wrapper';
import { SAM2TrackerAction } from './sam2-tracker-action';

/** Payload exchanged by the `lambda.list` hook: the listed models together with their count. */
type SAM2ModelListPayload = { models: MLModel[]; count: number };

/**
 * Shape of the plugin object that `enableTrackerFeatures` registers in CVAT core.
 */
interface SAM2TrackerBuilder {
    name: string;
    description: string;
    cvat: {
        lambda: {
            list: {
                /** Hook on `lambda.list`: receives the plugin and the listing, returns the (possibly modified) listing. */
                leave: (plugin: SAM2TrackerBuilder, data: SAM2ModelListPayload) => Promise<SAM2ModelListPayload>;
            };
        };
    };
    data: {
        /** Model IDs this plugin handles. */
        supportedModelIDs: string[];
        /** Flag recording that tracker action registration has happened. */
        isActionRegistered: boolean;
    };
}

/** Common prefix of the SAM2 tracker model IDs; the remainder of each ID names the weights variant. */
const modelNamePrefix = 'pth-facebookresearch-sam2-tracker-' as const;

/**
 * Registers the "Segment Anything 2 Tracker" plugin in CVAT core.
 *
 * The plugin hooks into `lambda.list`. For every supported SAM2 tracker model
 * found in the listing it registers a `SAM2TrackerAction`, at most once per
 * model ID. It then removes those models from the listing it returns.
 *
 * @param core - CVAT core instance used to register the plugin and the actions.
 */
export default function enableTrackerFeatures(core: CVATCore): void {
    // Model IDs whose action has already been registered. Tracked per model (not with a
    // single flag) so that a supported model first appearing in a later listing still gets
    // its action, while repeated listings never register the same action twice.
    const registeredModelIDs = new Set<string>();

    core.plugins.register({
        name: 'Segment Anything 2 Tracker',
        description: 'Plugin enables Segment Anything 2 tracker-related features',
        cvat: {
            lambda: {
                list: {
                    async leave(
                        plugin: SAM2TrackerBuilder,
                        data: { models: MLModel[], count: number },
                    ): Promise<{ models: MLModel[], count: number }> {
                        const filtered = data.models.filter((model) => (
                            typeof model.id === 'string' && plugin.data.supportedModelIDs.includes(model.id)
                        ));

                        filtered.forEach((model) => {
                            // the filter above only keeps models with a string id
                            const modelID = model.id as string;
                            if (registeredModelIDs.has(modelID)) {
                                return;
                            }

                            // every supported ID is `${modelNamePrefix}<weights>`, so stripping the
                            // prefix yields one of the two weight variants listed in supportedModelIDs
                            const weights = modelID.replace(modelNamePrefix, '') as 'large' | 'base-plus';
                            core.actions.register(new SAM2TrackerAction(modelID, weights, core));
                            registeredModelIDs.add(modelID);
                            plugin.data.isActionRegistered = true;
                        });

                        // These models are hidden from the returned list because the default
                        // tracking pipeline is currently disabled for them.
                        return {
                            models: data.models.filter((model) => !filtered.includes(model)),
                            count: data.count - filtered.length,
                        };
                    },
                },
            },
        },
        data: {
            supportedModelIDs: [
                `${modelNamePrefix}large`,
                `${modelNamePrefix}base-plus`,
            ],
            isActionRegistered: false,
        },
    });
}
