# Import Virtualitics AI Application internal elements
from virtualitics_sdk import (
    App, Step, StepType, Page, Section, Card, StoreInterface, InfographData, Infographic, 
    InfographicOrientation, CustomEvent, DataSource, Image, Table, ImageSize, Column, Row, 
    InfographDataType, Dropdown, DateTimeRange, PlotlyPlot
)

# Import external packages
import plotly.express as px
from PIL import Image as PILImage
from io import BytesIO
import pandas as pd
import string


def ExecutiveUpdater(store_interface: StoreInterface):
    """Refresh the "Executive Dashboard Plots" card after a filter change.

    Reads the "Countries" and "Date Range" filter values of the current step,
    filters the persisted "Predict Data" dataset accordingly, rebuilds the
    infographic, the table and the three plots, and writes them back to the
    page.

    Args:
        store_interface: Store handle of the running step; used to read the
            page, the filter values and the persisted dataset, and to persist
            the updated page.
    """
    page = store_interface.get_page()
    card = page.get_card_by_title("Executive Dashboard Plots")

    selected_country = store_interface.get_element_value(
        store_interface.step_name, "Countries"
    )
    date_range = store_interface.get_element_value(
        store_interface.step_name, "Date Range"
    )
    start_date = date_range["min"]
    end_date = date_range["max"]
    saved_df = store_interface.get_input("Predict Data")

    # Format the invoice dates once instead of once per bound.
    # NOTE(review): the bounds are compared as strings, which presumes the
    # filter returns "YYYY-MM-DD"-style values -- confirm against the SDK.
    invoice_day = saved_df["InvoiceDate"].dt.strftime("%Y-%m-%d")
    df = saved_df.loc[
        (saved_df["Country"].isin(selected_country))
        & (invoice_day >= start_date)
        & (invoice_day <= end_date)
    ]

    # Per-product totals; reused for the bar plot, the table and the infographic
    sales_value_by_product = df.groupby("Description").agg(
        TotalSalesValue=("SaleValue", "sum"),
        TotalQuantitySold=("Quantity", "sum"),
        MonthsSoldFor=("InvoiceMonth", "nunique"),
    )
    top_products = (
        sales_value_by_product.sort_values("TotalSalesValue", ascending=False)
        .head(8)
        .reset_index(level=0)
    )

    # Update the top products bar plot
    fig_top_products = px.bar(
        top_products,
        x="Description",
        y="TotalSalesValue",
        title="Top Revenue-Generating Products",
        labels={"Description": "Product Description", "TotalSalesValue": "Total Sales Value"},
        text="TotalSalesValue",
        template="predict_default"
    )
    fig_top_products.update_traces(texttemplate='%{text:.2s}', textposition='outside')
    top_products_bar = PlotlyPlot(fig_top_products)

    # Count the rows per country within the current selection and plot them
    country_counts = df.Country.value_counts()
    country_counts = pd.DataFrame(
        {"Country": country_counts.index, "Counts": country_counts.values}
    )
    sorted_country_counts = country_counts.sort_values("Counts", ascending=False)

    fig_buyers = px.bar(
        sorted_country_counts,
        x="Country",
        y="Counts",
        title="Most Orders by Market",
        labels={"Country": "Country", "Counts": "Order Counts"},
        text="Counts",
        template="predict_default"
    )
    fig_buyers.update_traces(texttemplate='%{text}', textposition='outside')
    buyers_bar = PlotlyPlot(fig_buyers)

    # Total sales per month; head(-1) drops the most recent month from the plot
    sales_by_month = df.groupby("InvoiceMonth").agg(
        TotalSalesValue=("SaleValue", "sum")
    )
    sales_by_month = (
        sales_by_month.sort_values("InvoiceMonth", ascending=True)
        .reset_index(level=0)
        .head(-1)
    )

    # Create the line plot
    sales_by_month_fig = px.line(
        sales_by_month,
        x="InvoiceMonth",
        y="TotalSalesValue",
        title="Total Sales Revenue by Month",
        labels={"InvoiceMonth": "Invoice Month", "TotalSalesValue": "Total Sales Revenue"},
        markers=True,
        template="predict_default"
    )
    sales_by_month_line = PlotlyPlot(sales_by_month_fig)

    # Sales per customer and month, plus the same customer's previous entry
    sales_by_client_month = df.groupby(["CustomerID", "InvoiceMonth"]).agg(
        CurrentMonthSales=("SaleValue", "sum")
    )
    sales_by_client_month = sales_by_client_month.sort_values(
        ["CustomerID", "InvoiceMonth"], ascending=True
    ).reset_index()
    # The frame is already ordered by customer and month, so shifting within
    # each customer yields the sales of that customer's preceding listed month.
    sales_by_client_month["PreviousMonthSales"] = (
        sales_by_client_month.groupby("CustomerID")["CurrentMonthSales"].shift(1)
    )

    # Get current month recurring clients
    current_sales_by_client_month = sales_by_client_month[
        sales_by_client_month.InvoiceMonth == sales_by_client_month.InvoiceMonth.max()
    ]
    current_clients_count = current_sales_by_client_month.shape[0]
    returning_clients_count = current_sales_by_client_month[
        "PreviousMonthSales"
    ].count()
    # A filter that matches no rows leaves zero clients; report 0% instead of
    # dividing by zero (which produced NaN and crashed in int() below).
    if current_clients_count:
        recurring_customers_ratio = (
            returning_clients_count / current_clients_count * 100
        )
    else:
        recurring_customers_ratio = 0

    # Get products sold consistently over time
    consistently_sold_products = (
        sales_value_by_product.sort_values("MonthsSoldFor", ascending=False)
        .head(2)
        .reset_index(level=0)
    )

    # Get products sold in highest quantities
    high_quantity_sold_products = (
        sales_value_by_product.sort_values("TotalQuantitySold", ascending=False)
        .head(2)
        .reset_index(level=0)
    )

    # Total sales value for the latest calendar year present in the selection
    current_year = df["InvoiceDate"].dt.year.max()
    ytd_rev = df[df["InvoiceDate"].dt.year == current_year]["SaleValue"].sum()

    # Create the InfographData elements needed and attach them to the Infographic
    executive_info_blocks = [
        InfographData(
            label="Recurring Customers",
            main_text=f"{int(recurring_customers_ratio)}%",
            supporting_text="of customers are retained month-to-month",
            icon="check_circle",
            _type=InfographDataType.POSITIVE,
        ),
        InfographData(
            label="Consistently Sold Products",
            main_text=", ".join(list(consistently_sold_products["Description"])),
            supporting_text="products with the most recurring sales",
            icon="warning",
            _type=InfographDataType.WARNING,
        ),
        InfographData(
            label="Highest Quantity Sold",
            main_text=", ".join(list(high_quantity_sold_products["Description"])),
            supporting_text="products with the highest logistic load",
            icon="check_circle",
            _type=InfographDataType.POSITIVE,
        ),
        InfographData(
            label="YTD Revenue",
            main_text=f"${ytd_rev:,.0f}",
            supporting_text="generated this year",
            icon="info",
            _type=InfographDataType.INFO,
        ),
    ]

    executive_info = Infographic(
        title="Insights on Revenue Growth Opportunities",
        description="All information is derived " + "from the past year of data",
        data=executive_info_blocks,
        layout=InfographicOrientation.ROW,
    )

    # Rebuild the table shown next to the monthly revenue plot
    table_features = [
        "Description",
        "TotalSalesValue",
        "TotalQuantitySold",
        "MonthsSoldFor",
    ]
    table = Table(
        content=top_products[table_features].head(20),
        title="Top Products Tabular Data",
    )

    # Swap each element of the card for its refreshed version, matched by title
    card.update_item("Top Products Tabular Data", table)
    card.update_item("Insights on Revenue Growth Opportunities", executive_info)
    card.update_item("Total Sales Revenue by Month", sales_by_month_line)
    card.update_item("Most Orders by Market", buyers_bar)
    card.update_item("Top Revenue-Generating Products", top_products_bar)
    store_interface.update_page(page)


# This step has the user upload the dataset we'll be using
class DataUploadStep(Step):
    """Input step that asks the user to upload the e-commerce CSV dataset."""

    def run(self, flow_metadata):
        """Add a required CSV upload element to the step page.

        Args:
            flow_metadata: Keyword arguments used to construct the
                ``StoreInterface`` for this run.
        """
        store = StoreInterface(**flow_metadata)
        upload_page = store.get_page()

        # Link shown in the upload element's description
        download_url = "https://docs.virtualitics.com/hc/en-us/articles/34931990206099#downloads"

        uploader = DataSource(
            title="Upload e-commerce data here!",
            options=["csv"],
            description=f"You can find the relevant data for this tutorial here: {download_url}",
            required=True,
        )

        # The page's only section has an empty title, hence the "" target
        upload_page.add_card_to_section(
            Card(title="Data Upload Card", content=[uploader]), ""
        )
        store.update_page(upload_page)


class KickOffEvent(CustomEvent):
    """Custom event attached to the recommendation infographic."""

    def callback(self, flow_metadata):
        """Log the flow metadata and return a confirmation message."""
        log_line = f"Running Kick-off {flow_metadata}"
        print(log_line)
        return "Executed Kick-off Custom Event"


# Step that builds an interactive page where users can navigate between two
# different dashboards using tabs.
class SimpleDashboards(Step):
    """Dashboard step with an executive and an operational dashboard card.

    Both cards are added to the "Simple Dash" section of the step page. The
    executive card carries "Date Range" and "Countries" filters whose changes
    are handled by ``ExecutiveUpdater``.
    """

    def run(self, flow_metadata):
        """Prepare the uploaded data and build both dashboard cards.

        The uploaded frame is enriched with ``SaleValue`` and ``InvoiceMonth``
        columns, persisted as "Predict Data" for later use, and summarised
        into infographics, plots and a table.

        Args:
            flow_metadata: Keyword arguments used to construct the
                ``StoreInterface`` for this run.

        Raises:
            IndexError: If the data covers fewer than three invoice months;
                the operational metrics compare the second and third most
                recent months.
        """
        store_interface = StoreInterface(**flow_metadata)
        current_page = store_interface.get_page()

        store_interface.update_progress(
            5, "Loading and processing data (this may take a few minutes)"
        )

        # Get the data from the data upload step, add the derived columns and
        # persist it for later use
        df = store_interface.get_element_value(
            data_upload_step.name, "Upload e-commerce data here!"
        )
        df["SaleValue"] = df["Quantity"] * df["UnitPrice"]
        df["InvoiceDate"] = pd.to_datetime(df["InvoiceDate"])
        df["InvoiceMonth"] = df["InvoiceDate"].dt.strftime("%Y-%m")
        df["Description"] = df["Description"].apply(lambda x: string.capwords(str(x)))
        store_interface.save_output(df, "Predict Data")

        store_interface.update_progress(25, "Building visualizations")

        # ---- First card: executive dashboard ----
        # Extract the top products by total sales and create a bar plot
        sales_value_by_product = df.groupby("Description").agg(
            TotalSalesValue=("SaleValue", "sum"),
            TotalQuantitySold=("Quantity", "sum"),
            MonthsSoldFor=("InvoiceMonth", "nunique"),
        )
        top_products = (
            sales_value_by_product.sort_values("TotalSalesValue", ascending=False)
            .head(8)
            .reset_index(level=0)
        )

        fig_top_products = px.bar(
            top_products,
            x="Description",
            y="TotalSalesValue",
            title="Top Revenue-Generating Products",
            labels={"Description": "Product Description", "TotalSalesValue": "Total Sales Value"},
            text="TotalSalesValue",
            template="predict_default"
        )
        fig_top_products.update_traces(texttemplate='%{text:.2s}', textposition='outside')
        top_products_bar = PlotlyPlot(fig_top_products)

        # Count the rows per country and keep the ten most active markets
        # outside of the UK for the initial plot and filter selection
        country_list = list(df["Country"].unique())
        country_counts = df.Country.value_counts()
        country_counts = pd.DataFrame(
            {"Country": country_counts.index, "Counts": country_counts.values}
        )
        sorted_country_counts = country_counts.sort_values("Counts", ascending=False)
        non_uk_buyers = sorted_country_counts[
            sorted_country_counts["Country"] != "United Kingdom"
        ].head(10)

        # Filters of the executive card; their titles are the keys that
        # ExecutiveUpdater uses to read the selected values
        country_selection = Dropdown(
            options=country_list,
            selected=non_uk_buyers.Country.unique().tolist(),
            multiselect=True,
            title="Countries",
        )

        date_range_element = DateTimeRange(
            min_range=df["InvoiceDate"].min(),
            max_range=df["InvoiceDate"].max(),
            title="Date Range",
        )

        fig_non_uk_buyers = px.bar(
            non_uk_buyers,
            x="Country",
            y="Counts",
            title="Most Orders by Market",
            labels={"Country": "Country", "Counts": "Order Counts"},
            text="Counts",
            template="predict_default"
        )
        fig_non_uk_buyers.update_traces(texttemplate='%{text}', textposition='outside')
        non_uk_buyers_bar = PlotlyPlot(fig_non_uk_buyers)

        # Total sales per month; head(-1) drops the most recent month, which
        # is skipped as incomplete elsewhere in this step as well
        sales_by_month = df.groupby("InvoiceMonth").agg(
            TotalSalesValue=("SaleValue", "sum")
        )
        sales_by_month = (
            sales_by_month.sort_values("InvoiceMonth", ascending=True)
            .reset_index(level=0)
            .head(-1)
        )

        # Create the line plot
        sales_by_month_fig = px.line(
            sales_by_month,
            x="InvoiceMonth",
            y="TotalSalesValue",
            title="Total Sales Revenue by Month",
            labels={"InvoiceMonth": "Invoice Month", "TotalSalesValue": "Total Sales Revenue"},
            markers=True,
            template="predict_default"
        )
        sales_by_month_line = PlotlyPlot(sales_by_month_fig)

        # Sales per customer and month, plus the same customer's previous entry
        sales_by_client_month = df.groupby(["CustomerID", "InvoiceMonth"]).agg(
            CurrentMonthSales=("SaleValue", "sum")
        )
        sales_by_client_month = sales_by_client_month.sort_values(
            ["CustomerID", "InvoiceMonth"], ascending=True
        ).reset_index()
        # The frame is already ordered by customer and month, so shifting
        # within each customer yields that customer's preceding listed month.
        sales_by_client_month["PreviousMonthSales"] = (
            sales_by_client_month.groupby("CustomerID")["CurrentMonthSales"].shift(1)
        )

        # Get current month recurring clients
        current_sales_by_client_month = sales_by_client_month[
            sales_by_client_month.InvoiceMonth
            == sales_by_client_month.InvoiceMonth.max()
        ]
        current_clients_count = current_sales_by_client_month.shape[0]
        returning_clients_count = current_sales_by_client_month[
            "PreviousMonthSales"
        ].count()
        # Without any identified customer the ratio is undefined; report 0%
        # instead of dividing by zero.
        if current_clients_count:
            recurring_customers_ratio = (
                returning_clients_count / current_clients_count * 100
            )
        else:
            recurring_customers_ratio = 0

        # Get products sold consistently over time
        consistently_sold_products = (
            sales_value_by_product.sort_values("MonthsSoldFor", ascending=False)
            .head(2)
            .reset_index(level=0)
        )

        # Get products sold in highest quantities
        high_quantity_sold_products = (
            sales_value_by_product.sort_values("TotalQuantitySold", ascending=False)
            .head(2)
            .reset_index(level=0)
        )

        # Total sales value for the latest calendar year present in the data
        current_year = df["InvoiceDate"].dt.year.max()
        ytd_rev = df[df["InvoiceDate"].dt.year == current_year]["SaleValue"].sum()

        # Create the InfographData elements needed and attach them to the Infographic
        executive_info_blocks = [
            InfographData(
                label="Recurring Customers",
                main_text=f"{int(recurring_customers_ratio)}%",
                supporting_text="of customers are retained month-to-month",
                icon="check_circle",
                _type=InfographDataType.POSITIVE,
            ),
            InfographData(
                label="Consistently Sold Products",
                main_text=", ".join(list(consistently_sold_products["Description"])),
                supporting_text="products with the most recurring sales",
                icon="warning",
                _type=InfographDataType.WARNING,
            ),
            InfographData(
                label="Highest Quantity Sold",
                main_text=", ".join(list(high_quantity_sold_products["Description"])),
                supporting_text="products with the highest logistic load",
                icon="check_circle",
                _type=InfographDataType.POSITIVE,
            ),
            InfographData(
                label="YTD Revenue",
                main_text=f"${ytd_rev:,.0f}",
                supporting_text="generated this year",
                icon="info",
                _type=InfographDataType.INFO,
            ),
        ]

        executive_info = Infographic(
            title="Insights on Revenue Growth Opportunities",
            description="All information is derived " + "from the past year of data",
            data=executive_info_blocks,
            layout=InfographicOrientation.ROW,
        )

        # Create a table and lay it out next to the monthly revenue plot
        table_features = [
            "Description",
            "TotalSalesValue",
            "TotalQuantitySold",
            "MonthsSoldFor",
        ]
        table = Table(
            content=top_products[table_features].head(20),
            title="Top Products Tabular Data",
        )
        top_row = Row([Column([top_products_bar], ratio=[0.7]), Column([non_uk_buyers_bar], ratio=[0.7])])
        bottom_row = Row([table, Column([sales_by_month_line], ratio=[0.7])])

        # Assemble the executive card with its filters and add it to the page
        executive_card = Card(
            title="Executive Dashboard Plots",
            content=[executive_info, top_row, bottom_row],
            show_title=False,
            filters=[date_range_element, country_selection],
            filter_update=ExecutiveUpdater,
        )

        current_page.add_card_to_section(executive_card, "Simple Dash")

        # ---- Second card: operational dashboard ----
        # Define months for recent month growth metrics. The months are sorted
        # explicitly ("%Y-%m" strings order chronologically) because unique()
        # returns them in order of appearance, which is only chronological
        # when the uploaded file happens to be sorted by date.
        invoice_months = sorted(df["InvoiceMonth"].dropna().unique())
        this_month = invoice_months[-2]  # Skipping most recent month due to incomplete data
        last_month = invoice_months[-3]
        this_month_dataset = df[df["InvoiceMonth"] == this_month]
        last_month_dataset = df[df["InvoiceMonth"] == last_month]

        # Calculate monthly customer growth
        customers_this_month = this_month_dataset["CustomerID"].nunique()
        customers_last_month = last_month_dataset["CustomerID"].nunique()
        monthly_customer_growth = (customers_this_month / customers_last_month) - 1

        # Identify fastest shrinking regions
        this_month_national_sales = this_month_dataset.groupby("Country")[
            "SaleValue"
        ].sum()
        last_month_national_sales = last_month_dataset.groupby("Country")[
            "SaleValue"
        ].sum()
        national_sales_growth = (
            this_month_national_sales / last_month_national_sales - 1
        )
        shrinking_markets = national_sales_growth.sort_values().head(3).index.tolist()
        shrinking_markets = ", ".join(shrinking_markets)

        # Identify fastest growing products
        this_month_product_sales = this_month_dataset.groupby("Description")[
            "SaleValue"
        ].sum()
        last_month_product_sales = last_month_dataset.groupby("Description")[
            "SaleValue"
        ].sum()
        product_sales_growth = this_month_product_sales / last_month_product_sales - 1
        growing_products = (
            product_sales_growth.sort_values(ascending=False).head(3).index.tolist()
        )
        growing_products = ", ".join(growing_products)

        # Calculate the revenue of the month referred to as last_month above
        monthly_rev = last_month_dataset["SaleValue"].sum()

        # Progress must move forward; 25% was already reported above
        store_interface.update_progress(60, "Building visuals")

        # Create the InfographData elements needed and attach them to the Infographic
        operational_info_blocks = [
            InfographData(
                label="Monthly Customer Growth",
                main_text=f"{round(monthly_customer_growth*100, 2)}%",
                supporting_text="change in customer count this month",
                icon="check_circle",
                _type=InfographDataType.POSITIVE,
            ),
            InfographData(
                label="Shrinking Markets",
                main_text=f"{shrinking_markets}",
                supporting_text="are the 3 fastest shrinking markets",
                icon="warning",
                _type=InfographDataType.WARNING,
            ),
            InfographData(
                label="Fastest Growing Products",
                main_text=f"{growing_products}",
                supporting_text="are the 3 fastest growing products",
                icon="check_circle",
                _type=InfographDataType.POSITIVE,
            ),
            InfographData(
                label="Revenue",
                main_text=f"${round(monthly_rev,2):,.0f}",
                supporting_text="generated in the last month",
                icon="info",
                _type=InfographDataType.INFO,
            ),
        ]
        operational_info = Infographic(
            title="Insights on Recent Performance",
            description="",
            data=operational_info_blocks,
            layout=InfographicOrientation.ROW,
        )

        # Turn the per-country growth into a frame for plotting
        national_growth_table = pd.DataFrame(
            {
                "Country": national_sales_growth.index,
                "Monthly Revenue Growth": national_sales_growth.values,
            }
        )
        # iloc[1:8] keeps ranks two to eight; the single top market is left out
        growing_markets = national_growth_table.sort_values(
            "Monthly Revenue Growth", ascending=False
        ).iloc[1:8]

        fig_barplot = px.bar(
            growing_markets,
            x="Country",
            y="Monthly Revenue Growth",
            title="Highest Monthly Revenue Growth by Market",
            template="predict_default"
        )
        growing_markets_barplot = PlotlyPlot(fig_barplot)

        # Total quantity per month, again without the most recent month
        quantity_by_month = df.groupby("InvoiceMonth").agg(
            LogisticsLoad=("Quantity", "sum")
        )
        quantity_by_month = (
            quantity_by_month.sort_values("InvoiceMonth", ascending=True)
            .reset_index(level=0)
            .head(-1)
        )

        fig_line = px.line(
            quantity_by_month,
            x="InvoiceMonth",
            y="LogisticsLoad",
            title="Logistics Load",
            labels={"InvoiceMonth" : "Invoice Month", "LogisticsLoad" : "Total Quantity Shipped"},
            markers=True,
            template="predict_default"
        )
        logistics_lineplot = PlotlyPlot(fig_line)

        # Create a Recommendation Infographic
        rec = InfographData(
            label="Recommended Action",
            main_text="**Kick-off segmentation based marketing campaign?**",
            supporting_text="Additional text can be placed here",
            icon="info",
            _type=InfographDataType.POSITIVE,
        )

        kick_off_event = KickOffEvent(
            "Kick-off", "Would you like to take the recommended action?"
        )
        info_rec = Infographic(
            title="Insights and Recommended Action",
            description="Suggested action based on model predictions",
            recommendation=[rec],
            layout=InfographicOrientation.ROW,
            event=kick_off_event,
        )

        # Assemble the operational card, add it to the page and persist the changes
        operational_card = Card(
            title="Operational Dashboard Plots",
            content=[
                operational_info,
                Row([Column([growing_markets_barplot], ratio=[0.7]), Column([logistics_lineplot], ratio=[0.7])]),
                info_rec,
            ],
        )
        current_page.add_card_to_section(operational_card, "Simple Dash")
        store_interface.update_page(current_page)


# Create a Step with Explore's connection and images
class ExploreConnectionStep(Step):
    """Results step that renders plots through the Explore (pyvip) client."""

    def run(self, flow_metadata, pyvip_client=None):
        """Load the persisted data into Explore and show the resulting images.

        Args:
            flow_metadata: Keyword arguments used to construct the
                ``StoreInterface`` for this run.
            pyvip_client: Explore client used to load the data and create the
                plots. NOTE(review): defaults to None but is used
                unconditionally -- presumably always supplied because the step
                is declared with ``uses_pyvip=True``; confirm.
        """

        def as_large_image(plot_result, title):
            # Decode the returned image bytes and wrap them in a page element
            picture = PILImage.open(BytesIO(plot_result["image_bytes"]))
            return Image(
                content=picture,
                title=title,
                description="",
                size=ImageSize.LARGE,
            )

        store = StoreInterface(**flow_metadata)
        explore_page = store.get_page()
        store.update_progress(25, "Extracting Data")

        # Dataset persisted earlier under the name "Predict Data"
        ecommerce = store.get_input("Predict Data")
        pyvip_client.load_data(ecommerce, "Ecommerce Data")

        # Axes shared by the plots below
        shared_axes = {"x": "InvoiceMonth", "y": "SaleValue", "color": "StockCode"}

        mapping_result = pyvip_client.smart_mapping(
            ecommerce["SaleValue"],
            exclude=["InvoiceNo", "Description", "CustomerID"],
            path=True,
        )
        fig_mapping = as_large_image(mapping_result, "Smart Mapping")

        line_result = pyvip_client.plot(
            plot_type="line", size_scale=1.0, path=True, **shared_axes
        )
        fig_line = as_large_image(line_result, "Line Plot")

        hist_result = pyvip_client.hist(size_scale=2.0, path=True, **shared_axes)
        fig_hist = as_large_image(hist_result, "Histogram")

        # Final line plot with a z axis; its return value is not used here
        pyvip_client.plot(
            plot_type="line", z="Country", size_scale=1.0, path=True, **shared_axes
        )

        card_description = (
            "Virtualitics’ Intelligent Exploration uses AI"
            " and ML models to analyze rich, multi-dimensional data "
            "and quickly finds the patterns on data."
        )
        explore_card = Card(
            title="What is Explore?",
            content=[fig_mapping, Row([fig_line, fig_hist], ratio=[0.5, 0.5])],
            description=card_description,
        )
        explore_page.add_card_to_section(explore_card, section_title="Explore section")
        store.update_page(explore_page)


# Instantiate flow and assign image
persona_dashboards = App(
    name="Persona Based Dashboard - User Deployed",
    description="New to writing apps in VAIP? Check out this tutorial app.",
    image_path="https://predict-tutorials.s3-us-gov-west-1.amazonaws.com/persona_based_dashboards_tile.jpeg",
)

# Step 1: data upload (a page with a single, untitled section)
data_upload_page = Page(title="Data Upload", sections=[Section("", [])])
data_upload_step = DataUploadStep(
    title="Data Upload",
    description="",
    parent="Uploading the data",
    type=StepType.INPUT,
    page=data_upload_page,
)

# Step 2: dashboards; both cards are added to the "Simple Dash" section
dashboard_content = Section("Simple Dash", [])
simple_dashboards_page = Page(
    title="Using Dashboards in Predict",
    sections=[dashboard_content],
)
simple_dashboards_step = SimpleDashboards(
    title="Plots and Tables",
    description="Plots and Table",
    parent="Dashboard",
    type=StepType.DASHBOARD,
    page=simple_dashboards_page,
)

# Step 3: Explore connection; uses_pyvip requests the Explore client
data_explore_content = Section("Explore section", [])
data_explore_page = Page(
    title="Explore Connection",
    sections=[data_explore_content],
)
data_explore_step = ExploreConnectionStep(
    title="Plot Explore",
    description="Plot Explore",
    parent="Explore Analysis",
    type=StepType.RESULTS,
    page=data_explore_page,
    uses_pyvip=True,
)

# Chain together the steps of the app in execution order
persona_dashboards.chain(
    [
        data_upload_step,
        simple_dashboards_step,
        data_explore_step,
    ]
)
