Source code for lux.vislib.altair.AltairChart

#  Copyright 2019-2020 The Lux Authors.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import re

import lux
import numpy as np
import pandas as pd
from lux.utils.date_utils import compute_date_granularity

import altair as alt


[docs]class AltairChart: """ AltairChart is a representation of a chart. Common utilities for charts that is independent of chart types should go here. See Also -------- altair-viz.github.io """
[docs] def __init__(self, vis): self.vis = vis self.data = vis.data self.tooltip = True # ----- START self.code modification ----- self.code = "" self.width = 160 self.height = 150 self.chart = self.initialize_chart() # self.add_tooltip() self.encode_color() self.add_title() self.apply_default_config()
# ----- END self.code modification ----- def __repr__(self): return f"AltairChart <{str(self.vis)}>"
[docs] def add_tooltip(self): if self.tooltip: self.chart = self.chart.encode(tooltip=list(self.vis.data.columns))
[docs] def apply_default_config(self): self.chart = self.chart.configure_title(fontWeight=500, fontSize=13, font="Helvetica Neue") self.chart = self.chart.configure_axis( titleFontWeight=500, titleFontSize=11, titleFont="Helvetica Neue", labelFontWeight=400, labelFontSize=9, labelFont="Helvetica Neue", labelColor="#505050", ) self.chart = self.chart.configure_legend( titleFontWeight=500, titleFontSize=10, titleFont="Helvetica Neue", labelFontWeight=400, labelFontSize=9, labelFont="Helvetica Neue", ) plotting_scale = lux.config.plotting_scale self.chart = self.chart.properties( width=self.width * plotting_scale, height=self.height * plotting_scale ) self.code += ( "\nchart = chart.configure_title(fontWeight=500,fontSize=13,font='Helvetica Neue')\n" ) self.code += "chart = chart.configure_axis(titleFontWeight=500,titleFontSize=11,titleFont='Helvetica Neue',\n" self.code += "\t\t\t\t\tlabelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue',labelColor='#505050')\n" self.code += "chart = chart.configure_legend(titleFontWeight=500,titleFontSize=10,titleFont='Helvetica Neue',\n" self.code += "\t\t\t\t\tlabelFontWeight=400,labelFontSize=8,labelFont='Helvetica Neue')\n" self.code += f"chart = chart.properties(width={self.width * plotting_scale},height={self.height * plotting_scale})\n"
[docs] def encode_color(self): color_attr = self.vis.get_attr_by_channel("color") if len(color_attr) == 1: color_attr_name = color_attr[0].attribute color_attr_type = color_attr[0].data_type if color_attr_type == "temporal": timeUnit = compute_date_granularity(self.vis.data[color_attr_name]) self.chart = self.chart.encode( color=alt.Color( str(color_attr_name), type=color_attr_type, timeUnit=timeUnit, title=color_attr_name, ) ) self.code += f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}',timeUnit='{timeUnit}',title='{color_attr_name}'))" else: self.chart = self.chart.encode( color=alt.Color(str(color_attr_name), type=color_attr_type) ) self.code += f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}'))\n" elif len(color_attr) > 1: raise ValueError( "There should not be more than one attribute specified in the same channel." )
[docs] def add_title(self): chart_title = self.vis.title if chart_title: self.chart = self.chart.encode().properties(title=chart_title) if self.code != "": self.code += f"chart = chart.encode().properties(title = '{chart_title}')"
[docs] def initialize_chart(self): return NotImplemented
[docs] @classmethod def sanitize_dataframe(self, df): from lux.utils.date_utils import is_timedelta64_series, timedelta64_to_float_seconds for attr in df.columns: # Check if dtype is unrecognized by Altair (#247) if str(df[attr].dtype) == "Float64": df[attr] = df[attr].astype(np.float64) # Check for timedelta64[...] dtype if is_timedelta64_series(df[attr]): df[attr] = timedelta64_to_float_seconds(df[attr]) # Altair can not visualize non-string columns # convert all non-string columns in to strings df = df.rename(columns={attr: str(attr)}) return df