# Import Virtualitics AI Application internal elements
from virtualitics_sdk import (
    App, Step, StepType, Card, Page, Section, StoreInterface, Dataset, Model, PlotlyPlot,
    InfographData, Infographic, InfographicOrientation, CustomEvent, Table, XAIDashboard, 
    Column, Row, InfographDataType, Dataset, AssetType, DataEncoding, Explainer
)

# Import external libraries
import pandas as pd
import plotly.express as px
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor


# This step contains a dashboard with a histogram, a bar plot, a table and two infographics
class InvestmentDashboardStep(Step):
    """Train a sale-price model and render the investment dashboard page.

    The step one-hot encodes a fixed list of categorical housing features,
    fits an XGBoost regressor on a 70/30 train/test split, derives
    "flippable" and "underpriced" house statistics from the test-set
    predictions, and lays out infographics, a table, a price histogram and a
    per-quarter bar chart on the current page.
    """

    @staticmethod
    def _price_bins(max_price):
        """Return ``(edges, labels)`` for 50k-wide, left-closed price bins.

        At least ten bins (0 to 500k) are always produced. More bins are
        appended when ``max_price`` is 500k or above, so that the maximum
        falls strictly inside the last left-closed bin instead of outside
        every bin.

        Args:
            max_price: Largest price that must be covered; NaN is tolerated
                and yields the ten default bins.

        Returns:
            A tuple ``(edges, labels)`` with ``len(edges) == len(labels) + 1``.
        """
        width = 50000
        n_bins = 10
        if pd.notna(max_price) and max_price >= 0:
            # +1 so the maximum itself is below the (excluded) top edge.
            n_bins = max(n_bins, int(max_price // width) + 1)
        edges = [i * width for i in range(n_bins + 1)]
        labels = [
            f"{low // 1000}-{high // 1000 - 1}.9k"
            for low, high in zip(edges, edges[1:])
        ]
        return edges, labels

    def run(self, flow_metadata):
        """Build the dashboard, persist datasets/model and update the page.

        Args:
            flow_metadata: Mapping that is expanded into the
                ``StoreInterface`` constructor as keyword arguments.
        """
        # Get the store interface and the current page
        store_interface = StoreInterface(**flow_metadata)
        current_page = store_interface.get_page()

        store_interface.update_progress(5, "Processing Data")

        # Load the "house_prices" dataset asset as a pandas DataFrame
        data = store_interface.get_asset("house_prices", type=AssetType.DATASET).object

        # One-hot encode categorical features and split dataset in train and test sets
        categorical_cols = [
            "LotShape",
            "LandContour",
            "LotConfig",
            "LandSlope",
            "Neighborhood",
            "Condition1",
            "BldgType",
        ]
        target_col = "SalePrice"

        dummies = pd.get_dummies(data[categorical_cols])
        ohe_features = list(dummies.columns)
        data = data.join(dummies)
        train_data, test_data = train_test_split(data, test_size=0.3, random_state=12)

        # Store processed data for use in successive steps
        store_interface.save_asset(
            Dataset(train_data, label="train_dataset", name="train_dataset")
        )
        store_interface.save_asset(
            Dataset(test_data, label="test_dataset", name="test_dataset")
        )

        store_interface.update_progress(30, "One Hot Encoding Categorical Data")

        # Build model
        store_interface.update_progress(40, "Creating Model")
        xgb = Model(
            model=XGBRegressor(
                learning_rate=0.1,
                n_jobs=-1,
                n_estimators=100,
                max_depth=5,
                eval_metric="mae",
                random_state=42,
            ),
            label="xgb",
            name="house sale price predictor",
        )

        # Train model (on the one-hot features only) and save as asset
        xgb.fit(train_data[ohe_features], train_data[target_col])
        store_interface.save_asset(xgb)

        # Apply trained model to test set
        test_target = test_data[target_col]
        test_pred = xgb.predict(test_data[ohe_features])
        test_results = pd.DataFrame(
            {"ActualPrice": test_target, "PredictedPrice": test_pred}
        )

        # We define houses as flippable if their predicted price is more
        # than 20% greater than what's listed
        flippable_houses = test_results[
            (test_results["PredictedPrice"] / test_results["ActualPrice"]) > 1.2
        ]

        # Number of underpriced houses: predicted price is greater than actual
        underpriced_house_count = test_results[
            test_results["PredictedPrice"] > test_results["ActualPrice"]
        ].shape[0]

        # Number of flippable houses: predicted price is more than 20% greater than actual
        flippable_house_count = flippable_houses.shape[0]

        if flippable_houses.empty:
            # The means below would be NaN and int(NaN) raises, so show a
            # placeholder instead of failing the whole step.
            capital_text = "N/A"
            return_text = "N/A"
        else:
            # Average listed price of flippable houses doubles as the capital required
            expected_sale_price = flippable_houses["PredictedPrice"].mean()
            actual_listed_price = flippable_houses["ActualPrice"].mean()
            capital_required = int(actual_listed_price)
            # Relative gap between average predicted and average listed price
            expected_return = (
                expected_sale_price - actual_listed_price
            ) / actual_listed_price
            capital_text = f"${capital_required:,.0f}"
            return_text = f"{round(expected_return*100, 2)}%"

        store_interface.update_progress(50, "Building Visualizations")

        # Save outputs for use in successive steps
        store_interface.save_output(categorical_cols, "categorical_cols")
        store_interface.save_output(ohe_features, "ohe_features")

        # Create the InfographData elements needed and attach them to the Infographic
        info_blocks = [
            InfographData(
                label="Capital Required",
                main_text=capital_text,
                supporting_text="Average listed price of flippable houses",
                icon="error",
                _type=InfographDataType.NEGATIVE,
            ),
            InfographData(
                label="Flippable Houses",
                main_text=f"{flippable_house_count}",
                supporting_text="Number of flippable houses in the area",
                icon="warning",
                _type=InfographDataType.WARNING,
            ),
            InfographData(
                label="Underpriced Houses",
                main_text=f"{underpriced_house_count}",
                supporting_text="Number of underpriced houses in the area",
                icon="check_circle",
                _type=InfographDataType.POSITIVE,
            ),
            InfographData(
                label="Expected Returns for Flippable Houses",
                main_text=return_text,
                supporting_text="Average expected return on flippable houses",
                icon="info",
                _type=InfographDataType.INFO,
            ),
        ]
        info = Infographic(
            title="Insights about Biggest Revenue Growth Opportunities",
            description="There is still enough time to place orders"
            + " for the displayed houses.",
            data=info_blocks,
            layout=InfographicOrientation.ROW,
        )

        # Create a table view of the first 50 rows; it is placed in the bottom Dashboard row
        table_features = ["Neighborhood", "YrSold", "SaleCondition", "SalePrice"]
        table = Table(
            data[table_features].head(50), title="House Sales Data", data_grid=True, editable=False
        )

        # Create a Recommendation Infograph with an Action Button
        rec = InfographData(
            label="Recommended Action",
            main_text="**Contact agent for underpriced and flippable "
            + "house opportunities.**",
            supporting_text="Additional text can be placed here",
            icon="info",
            _type=InfographDataType.POSITIVE,
        )
        contact_event = ContactEvent(
            "Contact", "Would you like to take the recommended action?"
        )
        info_rec = Infographic(
            title="Insights and Recommended Action",
            description="Suggested action based on model predictions",
            recommendation=[rec],
            layout=InfographicOrientation.ROW,
            event=contact_event,
        )

        # Bucket sale prices into 50k-wide ranges that cover every price in
        # the data; with fixed 0-500k edges, pricier houses would get a NaN
        # PriceRange.
        bin_edges, bin_labels = self._price_bins(data[target_col].max())
        data['PriceRange'] = pd.cut(
            data[target_col],
            bins=bin_edges,
            labels=bin_labels,
            include_lowest=True,
            right=False,
        )

        # Price distribution histogram
        fig = px.histogram(
            data_frame=data,
            x=target_col,
            color='PriceRange',
            title="Distribution of Housing Prices",
            nbins=len(bin_edges)-1,
            template="predict_default",
        )
        fig.update_traces(
            hovertemplate='Price Range: %{x}<br>Number of Houses: %{y}<extra></extra>',
            selector=dict(type="histogram"),
        )
        fig.update_layout(
            xaxis_title="Sale Price",
            yaxis_title="Number of Houses",
            xaxis=dict(tickprefix='$'),
            bargap=0.1,
            showlegend=False
        )
        histogram = PlotlyPlot(fig=fig, title="Distribution of Housing Prices", description="View the number of houses sold across various price ranges to identify popular market segments and investment opportunities.")

        # Derive a quarter-sold feature (months 1-3 -> Q1, ...) and plot transactions by quarter
        data["QtrSold"] = data["MoSold"].apply(lambda x: f"Q{(x-1)//3+1}")
        quarter_sold_counts = data.QtrSold.value_counts().sort_index()
        quarter_sold_counts = pd.DataFrame(
            {"QtrSold": quarter_sold_counts.index, "Counts": quarter_sold_counts.values}
        )
        fig2 = px.bar(
            data_frame=quarter_sold_counts,
            x="QtrSold",
            y="Counts",
            title="Houses Sold by Quarter",
            template="predict_default",
            color="QtrSold",
        )
        fig2.update_traces(
            hovertemplate='Quarter Sold: %{x}<br>Number Sold: %{y}<extra></extra>',
            selector=dict(type="bar"),
        )
        fig2.update_layout(
            xaxis_title="Quarter Sold",
            yaxis_title="Number of Houses Sold",
            legend_title_text="Quarter"
        )
        bar_quarter = PlotlyPlot(fig=fig2, title="Houses Sold by Quarter", description="Examine the number of houses sold each quarter to spot seasonal trends and optimize investment timing based on market activity.")

        # Create a Dashboard using rows and columns, add it and our recommendation to the page and persist it
        top_row = Row([info, Column([histogram], ratio=[0.75])], ratio=[0.2, 0.8])
        bottom_row = Row([table, Column([bar_quarter], ratio=[0.75])])

        data_link = "https://docs.virtualitics.com/hc/en-us/articles/34932030275347"

        element_card = Card(
            title="Insights",
            description=f"You can find the relevant data for this tutorial here: {data_link}",
            content=[top_row, bottom_row, info_rec],
            show_title=False,
            show_comments=True,
            show_export=True,
            show_share=True
        )

        current_page.add_card_to_section(element_card, "")
        store_interface.update_progress(90, "Updating Page")
        store_interface.update_page(current_page)


# This callback class is used to attach an action to the Recommendation Infographic. Only a very simple example is shown here,
# but it could perform actions such as fetching data from an API or pinging a web server to set up a recurring supply order
class ContactEvent(CustomEvent):
    """Custom event attached to the recommendation infographic."""

    def callback(self, flow_metadata):
        """Print the received flow metadata and return a confirmation string.

        Args:
            flow_metadata: Value handed to the callback; it is only
                formatted into the printed message.

        Returns:
            The fixed confirmation message.
        """
        log_line = "Running Contact {}".format(flow_metadata)
        print(log_line)
        confirmation = "Executed Contact Custom Event"
        return confirmation


# This step consists of a Scenario Planning Tool and an Explainer Dashboard
class InteractiveScenarioPlanningStep(Step):
    """Build the explainer plots and the scenario-planning (XAI) dashboard.

    The step reloads the train/test datasets and the "xgb" model asset,
    builds a SHAP-based explainer on the one-hot encoded training features,
    explains one sample house, scores the test set and adds an
    ``XAIDashboard`` plus a card with the explanation plots to the page.
    """

    def run(self, flow_metadata):
        """Create explainer assets, the XAI dashboard and update the page.

        Args:
            flow_metadata: Mapping that is expanded into the
                ``StoreInterface`` constructor as keyword arguments.
        """
        # Get the store interface and the current page
        store_interface = StoreInterface(**flow_metadata)
        current_page = store_interface.get_page()

        store_interface.update_progress(5, "Retrieving assets and inputs")

        # Retrieve required assets and outputs from last step
        train_dataset = store_interface.get_asset(
            "train_dataset", type=AssetType.DATASET
        )
        train_data = (
            train_dataset.object
        )  # Extracts pandas dataframe from dataset object
        test_dataset = store_interface.get_asset("test_dataset", type=AssetType.DATASET)
        test_data = test_dataset.object  # Extracts pandas dataframe from dataset object
        xgb = store_interface.get_asset("xgb", type=AssetType.MODEL)
        categorical_cols = store_interface.get_input("categorical_cols")
        ohe_features = store_interface.get_input("ohe_features")

        store_interface.update_progress(15, "Building Explainer")

        # Create explainer training data asset
        explain_dataset = Dataset(
            train_data[ohe_features],
            label="explain_dataset",
            name="explainer modeling set",
            categorical_cols=categorical_cols,
            encoding=DataEncoding.ONE_HOT,
        )
        store_interface.save_asset(explain_dataset)

        # Create explainer for model
        explainer = Explainer(
            model=xgb,
            training_data=explain_dataset,
            output_names=["SalePrice"],
            mode="regression",
            label="explainer",
            name="ensemble model",
            use_shap=True,
            use_lime=False,
        )
        store_interface.save_asset(explainer)

        store_interface.update_progress(30, "Processing Data")

        # Specify sample instances to use with Explainer.
        # NOTE(review): index 194 is hard-coded; presumably it is a house that
        # the previous step classified as "flippable" for the tutorial data --
        # confirm it is present in the test split if the data or seed changes.
        flippable = test_data[test_data.index == 194]
        explain_instances = pd.concat(
            [flippable]
        )  # More than one sample instance can be used
        titles = [
            "This House's Predicted Price is 30% Higher " + "than Its Listed Price"
        ]
        plots = explainer.explain(
            explain_instances[ohe_features],
            method="manual",
            titles=titles,
            expected_title="Avg. Listed Price",
            predicted_title="Predicted Price",
            return_as="plots",
        )

        # Generate predictions for test set; negative predictions are clipped to 0
        test_data["PredictedPrice"] = xgb.predict(test_data[ohe_features])
        test_data["PredictedPrice"] = test_data["PredictedPrice"].clip(lower=0)
        # Turn the result into a dataset asset and persist it
        prediction_dataset = Dataset(
            test_data[ohe_features + ["PredictedPrice"]],
            label="prediction_dataset",
            name="explainer test set",
            categorical_cols=categorical_cols,
            encoding=DataEncoding.ONE_HOT,
        )
        store_interface.save_asset(prediction_dataset)

        # Create the Explainer Dashboard bounds: [min, max] of each training feature
        train_mins = train_data[ohe_features].min()
        train_maxs = train_data[ohe_features].max()
        bounds = {
            key: [train_mins[key], train_maxs[key]]
            for key in train_data[ohe_features].columns
        }

        store_interface.update_progress(50, "Building XAI Dashboard")

        # Create the Explainer Dashboard
        store_interface.update_progress(
            70, "Creating Scenario Planning Tool Dashboard"
        )
        xai_dash = XAIDashboard(
            xgb,
            explainer,
            prediction_dataset,
            "LandContour",
            "PredictedPrice",
            "PredictedPrice",
            title="",
            bounds=bounds,
            description=XAIDashboard.xai_dashboard_description(),
            train_data=explain_dataset,
            expected_title="Avg. Listed Price",
            predicted_title="Predicted Price",
            encoding=DataEncoding.ONE_HOT,
        )
        current_page.add_content_to_section(xai_dash, "")

        # Add the explanation plots to the page in their own card
        xai_plots_card = Card(
            title="Explanation of a Flippable House",
            content=plots,
        )
        current_page.add_card_to_section(xai_plots_card, "")

        store_interface.update_progress(90, "Updating page")

        store_interface.update_page(current_page)


# Module-level app definition: create the App object, build one Page and one
# Step per dashboard, then chain the two steps in execution order.
predictive_modeling = App(
    name="Predictive Modeling Tutorial - User Deployed",
    description="New to writing Virtualitics Apps? Check out this Predictive Modeling tutorial app.",
    image_path="https://predict-tutorials.s3-us-gov-west-1.amazonaws.com/predictive_modeling_tile.jpeg"
)

# Build investment dashboard step.
# The page has a single section with an empty title; the step adds its card
# to that section by passing "" as the section name.
data_step_page = Page(
    title="Investment Dashboard",
    sections=[Section("", [])],
)
executive_dashboard_step = InvestmentDashboardStep(
    title="Investment Dashboard",
    description="",
    parent="Data & Visualizations",
    type=StepType.DASHBOARD,
    page=data_step_page,
)

# Build scenario planning dashboard step.
# Same single empty-titled section layout as the page above.
sp_step_page = Page(
    title="Scenario Planning - Flipping Homes",
    sections=[Section("", [])],
)
scenario_planning_step = InteractiveScenarioPlanningStep(
    title="Scenario Planning",
    description="",
    parent="Scenario Planning",
    type=StepType.RESULTS,
    page=sp_step_page,
)

# Chain together the steps of the app.
# The dashboard step is listed first: it saves the datasets, model and
# outputs that the scenario planning step reads back.
predictive_modeling.chain([executive_dashboard_step, scenario_planning_step])