# Import classes from the Virtualitics SDK
from virtualitics_sdk import (
    App, Step, StepType, Page, Section, Card, CustomEvent, DataSource, Dropdown, NumericRange, 
    DateTimeRange, TextInput, Table, InfographData, InfographDataType, Infographic, Column, Dataset,
    StoreInterface, create_bar_plot, create_line_plot, create_scatter_plot, Row
    )

# Import external packages
from virtualitics import vip_annotation
import pandas as pd
import collections
import random
import string
from datetime import datetime

# Element titles shared across steps. Each string is used twice: once when the
# input element is created (in DataQuery) and again as the lookup key passed to
# store_interface.get_element_value(...) in later steps, so the strings must
# match exactly in both places — hence the module-level constants.
DROPDOWN_TITLE = (
    "Dropdown Title: Which stock do you want to do an individual analysis on?"
)
RANGE_TITLE = "Choose: How many companies do you want to show?"
DATE_TITLE = "Date Range Title (Optional): Choose a date range for the analysis"
TEXT_INPUT_TITLE = "Text Field: Type a sentence to get a word count analysis!"


# This step has the user upload the dataset we'll be using
class DataUpload(Step):
    # Shared dataset label; later steps reference it when saving/loading data.
    dataset_name = "SP 500 Dataset"

    def run(self, flow_metadata):
        """Render a CSV upload card so the user can provide the S&P 500 data."""
        # Resolve the current page and its (untitled) section, then title it.
        store_interface = StoreInterface(**flow_metadata)
        page = store_interface.get_page()
        section = page.get_section_by_title("")
        section.title = "Loading Data from Databases, Datastores, and Data Lakes"

        # Reference links surfaced in the upload element's description text.
        guide_link = "https://docs.virtualitics.com/hc/en-us/articles/36051170263571"
        data_link = "https://docs.virtualitics.com/hc/en-us/articles/34931948678163#downloads"
        description = "\n \n".join(
            [
                f"To learn more about how to load data from databases, datastores,"
                f" and data lakes, check out this link: {guide_link}",
                f"You can find the relevant data for this tutorial here: {data_link}",
            ]
        )

        # Card holding the single CSV DataSource element.
        upload_card = Card(
            title="Data Source Title (Optional): Upload the stock ticker data!",
            content=[
                DataSource(
                    title="S&P 500 Dataset",
                    options=["csv"],
                    description=description,
                    show_title=False,
                )
            ],
        )

        # Attach the card to the untitled section and push the page update.
        page.add_card_to_section(upload_card, "")
        store_interface.update_page(page)


# This step allows the user to specify query parameters for our data before analysis
# It references the data uploaded in the previous step to set selection boundaries for the query parameters
class DataQuery(Step):
    def run(self, flow_metadata):
        """Build the query-parameter inputs: dropdown, numeric range, date range, text.

        Selection boundaries (ticker list, min/max dates) are derived from the
        dataset uploaded in the previous step.
        """
        store_interface = StoreInterface(**flow_metadata)
        page = store_interface.get_page()

        # Progress bar updates: (percent filled, message shown alongside it).
        store_interface.update_progress(5, "Creating query selection options")

        # Pull the uploaded dataset and extract valid bounds for the inputs.
        data = store_interface.get_element_value(
            data_upload_step.name, "S&P 500 Dataset"
        )
        tickers = sorted(data["Name"].unique())
        earliest = datetime.strptime(data["date"].min(), "%Y-%m-%d")
        latest = datetime.strptime(data["date"].max(), "%Y-%m-%d")

        store_interface.update_progress(75, "Building input elements")

        # One card per input element; each element's title doubles as the key
        # later steps use with get_element_value(...).
        cards = [
            Card(
                "Dropdown",
                [
                    Dropdown(
                        options=tickers,
                        title=DROPDOWN_TITLE,
                        description="Description for dropdown",
                        label="Stock Ticker",
                        placeholder="Select One",
                    )
                ],
                show_title=False,
            ),
            Card(
                "Numeric Range",
                [
                    NumericRange(
                        min_range=15,
                        max_range=500,
                        min_selection=15,
                        max_selection=500,
                        title=RANGE_TITLE,
                        description="Description for numeric range (Optional)",
                    )
                ],
                show_title=False,
            ),
            Card(
                "Data Range",
                [
                    DateTimeRange(
                        min_range=earliest,
                        max_range=latest,
                        title=DATE_TITLE,
                        description="Description for date range (Optional)",
                    )
                ],
                show_title=False,
            ),
            Card(
                "Text Input",
                [
                    TextInput(
                        title=TEXT_INPUT_TITLE,
                        description="Description for Text Field (Optional)",
                        show_title=False,
                        label="Sentence",
                        placeholder="Type a sentence...",
                    )
                ],
                show_title=False,
            ),
        ]

        # Attach every card to the untitled section and push the page update.
        store_interface.update_progress(95, "Updating page")
        for card in cards:
            page.add_card_to_section(card, "")
        store_interface.update_page(page)


# This step creates various plots after applying the query specified in the previous step to our data
# This step creates various plots after applying the query specified in the previous step to our data
class DataVisualization(Step):
    def run(self, flow_metadata):
        """Filter the uploaded data by the user's query, then render a table and plots.

        Reads the uploaded dataset plus the dropdown/range/date selections from
        the earlier steps, preprocesses the data (dtype coercion, random ticker
        sampling, date filtering, percent-change columns), saves the result as
        the "Preprocessed Data" output, and adds a table card and a plots card.
        """
        # Get store_interface and current page
        store_interface = StoreInterface(**flow_metadata)
        page = store_interface.get_page()

        # Fetch data from upload step
        data = store_interface.get_element_value(
            data_upload_step.name, "S&P 500 Dataset"
        )

        # Fetch query parameters from previous step (element titles are the keys)
        stock_analyzing = store_interface.get_element_value(
            data_query_step.name, DROPDOWN_TITLE
        )
        stock_count_range = store_interface.get_element_value(
            data_query_step.name, RANGE_TITLE
        )
        # Number of tickers to sample = width of the selected numeric range.
        stock_count = round(stock_count_range["max"]) - round(stock_count_range["min"])

        date_range = store_interface.get_element_value(data_query_step.name, DATE_TITLE)
        min_date = date_range["min"]
        max_date = date_range["max"]

        # Begin preprocessing
        store_interface.update_progress(5, "Preprocessing the Dataset")

        # Force datetime dtype and clean up column names
        data["date"] = pd.to_datetime(data.date)
        data["ticker"] = data["Name"].astype("string")
        col_order = ["date", "ticker", "open", "high", "low", "close", "volume"]
        data = data[col_order]

        # Randomly sample the user-specified number of tickers.
        # NOTE(review): random.sample raises ValueError if stock_count exceeds
        # the number of unique tickers — presumably the 15..500 range bounds
        # were chosen to prevent that; confirm against the dataset.
        stocks_subset = random.sample(list(data["ticker"].unique()), stock_count)
        if stock_analyzing not in stocks_subset:
            # Drop one sampled ticker to make room for the user's chosen one.
            stocks_subset = stocks_subset[:-1] + [stock_analyzing]

        data = data[data["ticker"].isin(stocks_subset)]

        # Filter data to user-specified date range.
        # The [:10] slice assumes the stored range values are ISO-like strings
        # starting "YYYY-MM-DD" — TODO confirm the DateTimeRange value format.
        data = data[
            (data["date"] >= datetime.strptime(min_date[:10], "%Y-%m-%d"))
            & (data["date"] <= datetime.strptime(max_date[:10], "%Y-%m-%d"))
        ]

        # Derive intraday percent change and its direction for plotting/coloring.
        data["pct_change"] = ((data["close"] - data["open"]) / data["open"]) * 100
        data["change_type"] = data["pct_change"].apply(
            lambda x: "gain" if x > 0 else "none" if x == 0 else "loss"
        )

        # Save preprocessed data as an output; ExploreNow and SaveAssets read it.
        store_interface.save_output(data, "Preprocessed Data")
        data_ticker = data.loc[data["ticker"] == stock_analyzing]
        # Display data table for just the chosen ticker
        data_table = Table(
            data_ticker, title="Table of Chosen Ticker", show_description=False
        )

        table_card = Card(
            title="Card with Table of Chosen Ticker",
            content=[data_table],
            description="The Table below displays "
            + "information related to the Chosen Ticker. "
            + "Please note that in the Predict Section,"
            + " no more than 1000 rows can be displayed.",
        )

        page.add_card_to_section(table_card, "")

        # Scatter plot: volume vs intraday percent change for the chosen ticker.
        scatter_plot = create_scatter_plot(
            data[data.ticker == stock_analyzing],
            "volume",
            "pct_change",
            color_by="change_type",
            plot_title=f"Scatter Plot: {stock_analyzing} Volume vs % Change in Daily Price",
        )

        # Monthly mean closing price for the chosen ticker.
        # NOTE(review): freq="M" (month-end) is a deprecated alias in newer
        # pandas (renamed "ME") — confirm the pinned pandas version.
        df_agg_indiv = pd.DataFrame(
            data[data.ticker == stock_analyzing]
            .groupby(pd.Grouper(key="date", freq="M"))["close"]
            .mean()
        )
        # Promote the period index to a column and alias close -> price.
        df_agg_indiv["date"] = df_agg_indiv.index
        df_agg_indiv["price"] = df_agg_indiv.close

        # Line plot of the monthly mean price over time.
        line_plot = create_line_plot(
            df_agg_indiv,
            "date",
            "price",
            plot_title=f"Line Plot: {stock_analyzing} Stock Price Over Time",
        )

        # Per-ticker change over the filtered window: open on the first date vs
        # close on the last date, joined on the ticker index level.
        df_start = data[data["date"] == data["date"].min()][
            ["ticker", "open"]
        ].set_index("ticker")
        df_end = data[data["date"] == data["date"].max()][
            ["ticker", "close"]
        ].set_index("ticker")
        df_year = df_start.join(df_end, on="ticker", how="inner")
        df_year["ticker"] = df_year.index
        df_year["pct_change"] = (
            (df_year["close"] - df_year["open"]) / df_year["open"]
        ) * 100

        # Bar plot of the 15 tickers with the largest percent gains.
        bar_plot = create_bar_plot(
            df_year.sort_values("pct_change", ascending=False).head(15),
            "ticker",
            "pct_change",
            plot_title="Bar Chart: Top 15 Largest Gainers",
        )

        # Stack the three plots vertically, each in its own row.
        plots_card = Card(
            title="Plots", 
            content=[Row([Column([scatter_plot], ratio=[0.7])]),
                     Row([Column([line_plot], ratio=[0.7])]),
                     Row([Column([bar_plot], ratio=[0.7])]),
                    ]
        )

        # Add card to section and update page
        page.add_card_to_section(plots_card, "")
        store_interface.update_page(page)


# This creates a custom event for us to place in the step defined below
class ExploreNow(CustomEvent):
    def callback(self, flow_metadata, pyvip_client=None):
        """Load the preprocessed data into Explore and build two annotated, linked plots.

        Args:
            flow_metadata: Forwarded to StoreInterface to locate this flow's state.
            pyvip_client: Live pyvip connection to an Explore instance. Defaults
                to None, in which case the event exits early (nothing to drive).

        Returns:
            A short status string describing what happened.
        """
        print(f"Running ExploreNow {flow_metadata}")

        # Check that we actually have a valid connection to Explore before
        # using it: the parameter defaults to None, and the original code
        # called methods on it unconditionally, raising AttributeError when
        # no client was supplied.
        if pyvip_client is None:
            return "ExploreNow skipped: no connection to Explore"

        # Within a custom event we need a fresh StoreInterface to read inputs.
        store_interface = StoreInterface(**flow_metadata)
        data = store_interface.get_input("Preprocessed Data")

        # Only the columns Explore needs for the two mappings below.
        explore_data = data[
            ["date", "open", "close", "high", "low", "volume", "change_type"]
        ]

        pyvip_client.load_data(explore_data, dataset_name=DataUpload.dataset_name)

        # Mapping 0: 3D scatter of open/close/volume, colored by change_type.
        pyvip_client.plot(
            plot_type="scatter",
            x="open",
            y="close",
            z="volume",
            color="change_type",
            size_scale=1.0,
        )
        first_stock_annotation = pyvip_client.create_annotation(
            vip_annotation.AnnotationType.MAPPING,
            name="Scatter Plot",
            comment="In this plot are considering previously preprocessed data."
            + "The points are colored by type of change, that could be: GAIN, LOSS or None.",
            isCollapsed=False,
            mappingID=0,
            width=0.3,
            height=0.2,
        )["result"]

        # Mapping 1: line plot over date/high/low with the same coloring.
        pyvip_client.plot(
            plot_type="line",
            x="date",
            y="high",
            z="low",
            color="change_type",
            size_scale=1.0,
        )
        second_stock_annotation = pyvip_client.create_annotation(
            vip_annotation.AnnotationType.MAPPING,
            name="Line plot",
            comment="In this plot are considering same data of plot 1. We use also the same coloration.",
            mappingID=1,
            isCollapsed=False,
            width=0.3,
            height=0.2,
        )["result"]

        # Cross-link each annotation to the other mapping. Both results are
        # guarded now — previously only the second one was checked, so a
        # falsy first result crashed on .id.
        if first_stock_annotation:
            pyvip_client.link_annotation(
                id=first_stock_annotation.id,
                linkLatestMapping=False,
                linkedDatasetName=DataUpload.dataset_name,
                linkedMappingID=1,
            )

        if second_stock_annotation:
            pyvip_client.link_annotation(
                id=second_stock_annotation.id,
                linkLatestMapping=False,
                linkedDatasetName=DataUpload.dataset_name,
                linkedMappingID=0,
            )

        return "Executed ExploreNow Custom Event"


# This step creates a few more elements to further demonstrate the functionality of Predict
class CreateAdditionalElements(Step):
    def run(self, flow_metadata, **kwargs):
        """Add a recommendation infographic (with the ExploreNow event) and a
        word/character-count infographic built from the user's typed sentence."""
        # Get store_interface and current page
        store_interface = StoreInterface(**flow_metadata)
        page = store_interface.get_page()

        # Create recommendation infographic with custom event
        explore_configuration_link = (
            "https://docs.virtualitics.com/hc/en-us/articles/24987782824211"
        )
        recommendation = InfographData(
            label="Recommended Action",
            main_text=f"""Insight Can Go Here: Load your data into Virtualitics Explore! 
            To learn more about configuring your Explore instance, go here: {explore_configuration_link}""",
            supporting_text="Additional supporting text can be added here",
            icon="info",
            _type=InfographDataType.POSITIVE,
        )

        explore_now_description = """If you're connected to Explore, hitting "Submit" will open a visualization of 
        your queried data in Explore."""
        explore_now_event = ExploreNow("Explore Now", explore_now_description)

        recommendation_infograph = Infographic(
            title="",
            description="",
            recommendation=[recommendation],
            event=explore_now_event,
        )

        # Retrieve the sentence typed in the query step (element title is the key).
        text_str = store_interface.get_element_value(
            data_query_step.name, TEXT_INPUT_TITLE
        )

        # Create text analysis infographic
        store_interface.update_progress(5, "Creating infographic")

        # Tokenize once instead of re-splitting for each metric.
        words = text_str.split()
        words_count = str(len(words)) if words else "[No text entered]"
        characters_count = len(text_str)
        # Count punctuation marks; the set is hoisted so membership tests are O(1).
        punctuation_chars = set(string.punctuation)
        punctuation_count = sum(1 for ch in text_str if ch in punctuation_chars)
        # Most frequent non-space character. Checking the de-spaced text (not the
        # raw text) also covers whitespace-only input, which previously raised
        # IndexError on most_common(1)[0].
        despaced = text_str.replace(" ", "")
        if despaced:
            most_common_char_value = collections.Counter(despaced).most_common(1)[0][0]
        else:
            most_common_char_value = "[No text entered]"

        total_words = InfographData(
            label="Total Words",
            main_text=words_count,
            supporting_text="Total words in uploaded text.",
            icon="error",
            _type=InfographDataType.NEGATIVE,
        )

        total_characters = InfographData(
            label="Total Characters",
            main_text=str(characters_count),
            supporting_text="Total characters contained in the uploaded text.",
            icon="warning",
            _type=InfographDataType.WARNING,
        )

        total_punctuation = InfographData(
            label="Total Punctuation",
            main_text=str(punctuation_count),
            supporting_text="Total punctuation marks contained in the uploaded text.",
            icon="check_circle",
            _type=InfographDataType.POSITIVE,
        )

        most_common_character = InfographData(
            label="Most Common Character",
            main_text=most_common_char_value,
            supporting_text="Was the most common character contained in the uploaded text. ",
            icon="info",
            _type=InfographDataType.INFO,
        )

        infograph_data_parts = [
            total_words,
            total_characters,
            total_punctuation,
            most_common_character,
        ]

        string_analysis_infograph = Infographic(
            title="Infographic (Good for KPIs) - Word count analysis based on the text uploaded in the last step",
            description="In this description, you can provide some basic context for the metrics displayed in the rest\
                of the infographic",
            data=infograph_data_parts,
        )

        # Add both infographics to one card and push the page update.
        store_interface.update_progress(90, "Updating page")
        info_card = Card(
            title="Key Insights",
            content=[recommendation_infograph, string_analysis_infograph],
        )
        page.add_card_to_section(info_card, "")

        store_interface.update_page(page)


# This step takes an output from a previous step and saves as an asset for the user
class SaveAssets(Step):
    def run(self, flow_metadata):
        """Persist the preprocessed dataframe as a reusable platform asset.

        Loads the "Preprocessed Data" output produced by the visualization
        step, wraps it in a Dataset asset, saves it, and adds an explanatory
        text-only card to the page.
        """
        # Get store interface and current page
        store_interface = StoreInterface(**flow_metadata)
        page = store_interface.get_page()

        store_interface.update_progress(5, "Saving the Dataset Asset")

        # Retrieve preprocessed Dataset from previous step and save as an asset
        data = store_interface.get_input("Preprocessed Data")
        dataset_asset = Dataset(
            dataset=data, label=DataUpload.dataset_name, name=DataUpload.dataset_name
        )
        store_interface.save_asset(dataset_asset)

        # Add a text-only card (no interactive elements) describing the asset.
        save_asset_title = "Asset has been Saved!"
        # The label quoted in the subtitle is interpolated from the label used
        # above; previously the text hard-coded "sp500_dataset", which did not
        # match the actual label ("SP 500 Dataset").
        save_asset_subtitle = f"""Assets allow you to persist results, data, models, and python objects (as a pickle) 
        so that you can use them in other flows.
        In this case, we’ve saved the dataframe we are using for this analysis with the label “{DataUpload.dataset_name}”."""
        save_asset_description = "You can see this asset listed under the Assets tab in the menu to the left!"
        page.add_content_to_section(
            elems=[],
            section_title="",
            card_title=save_asset_title,
            card_subtitle=save_asset_subtitle,
            card_description=save_asset_description,
        )

        # Add cards to section and update page
        store_interface.update_progress(95, "Updating page")
        store_interface.update_page(page)


# Instantiate app and assign image
hello_world = App(
    name="Hello World! - User Deployed",
    description="A basic introduction to the design and function of apps.",
    image_path="https://predict-tutorials.s3-us-gov-west-1.amazonaws.com/hello_world_tile.jpeg",
)

# Build data upload step
data_upload_section = Section("", [])
data_upload_page = Page("Getting Data Into the Platform", [data_upload_section])
data_upload_step = DataUpload(
    title="Data Upload",
    description="Upload the S&P 500 data.",
    parent="Inputs",
    type=StepType.INPUT,
    page=data_upload_page,
)

# Build query step
data_query_section = Section("", [])
data_query_page = Page("Additional Inputs and Parameters", [data_query_section])
data_query_step = DataQuery(
    title="Additional Inputs",
    description="Build data query.",
    parent="Inputs",
    type=StepType.INPUT,
    page=data_query_page,
)

# Build data visualization step
data_visualization_content = Section("", [])
data_visualization_page = Page(
    "Displaying Analytics through Page Elements",
    [data_visualization_content],
)
data_visualization_step = DataVisualization(
    title="Displaying Analytics through Page Elements",
    description="Show time-series of Dataset and perform preprocessing steps.",
    parent="Analytics and Page Elements",
    type=StepType.RESULTS,
    page=data_visualization_page,
)

# Build additional elements step
additional_elements_content = Section("", [])
additional_elements_page = Page(
    "Displaying Analytics through Page Elements",
    # Bug fix: this page previously reused data_visualization_content, which
    # left additional_elements_content unused and shared one Section object
    # between two different pages.
    [additional_elements_content],
)
additional_elements_step = CreateAdditionalElements(
    title="Additional Elements and Events",
    description="Create additional elements for demonstration.",
    parent="Analytics and Page Elements",
    type=StepType.RESULTS,
    page=additional_elements_page,
)

# Build asset saving step
save_assets_content = Section("", [])
save_assets_page = Page("Saving Assets", [save_assets_content])
save_assets_step = SaveAssets(
    title="Saving Assets",
    description="Read data from prior step and save as asset.",
    parent="Analytics and Page Elements",
    # NOTE(review): this step only saves output yet is typed INPUT, unlike the
    # other results steps — confirm whether StepType.RESULTS was intended.
    type=StepType.INPUT,
    page=save_assets_page,
)

# Chain together steps of app; list order defines the flow's step sequence.
hello_world.chain(
    [
        data_upload_step,
        data_query_step,
        data_visualization_step,
        additional_elements_step,
        save_assets_step,
    ]
)
