#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from collections import Counter
from typing import List, Optional, Type, Union, no_type_check, overload, TYPE_CHECKING
from warnings import catch_warnings, simplefilter, warn

from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import (
    IntegralType,
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    BooleanType,
    MapType,
    TimestampType,
    TimestampNTZType,
    DayTimeIntervalType,
    StructType,
    DataType,
)
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync

if TYPE_CHECKING:
    import numpy as np
    import pyarrow as pa
    from py4j.java_gateway import JavaObject

    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
    from pyspark.sql import DataFrame


class PandasConversionMixin:
    """
    Mix-in for the conversion from Spark to pandas. Currently, only
    :class:`DataFrame` can use this class.
    """

    def toPandas(self) -> "PandasDataFrameLike":
        """
        Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

        This is only available if Pandas is installed and available.

        .. versionadded:: 1.3.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Notes
        -----
        This method should only be used if the resulting Pandas ``pandas.DataFrame`` is
        expected to be small, as all the data is loaded into the driver's memory.

        Usage with ``spark.sql.execution.arrow.pyspark.enabled=True`` is experimental.

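        For example, the Arrow-based path and its fallback can be toggled per session; an
        illustrative snippet using the configuration keys this method checks::

            spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
            spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "true")
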
        Examples
        --------
        >>> df.toPandas()  # doctest: +SKIP
           age   name
        0    2  Alice
        1    5    Bob
        """
        from pyspark.sql.dataframe import DataFrame

        assert isinstance(self, DataFrame)

        from pyspark.sql.pandas.utils import require_minimum_pandas_version

        require_minimum_pandas_version()

        import numpy as np
        import pandas as pd
        from pandas.core.dtypes.common import is_timedelta64_dtype

        jconf = self.sparkSession._jconf
        timezone = jconf.sessionLocalTimeZone()

        if jconf.arrowPySparkEnabled():
            use_arrow = True
            try:
                from pyspark.sql.pandas.types import to_arrow_schema
                from pyspark.sql.pandas.utils import require_minimum_pyarrow_version

                require_minimum_pyarrow_version()
                to_arrow_schema(self.schema)
            except Exception as e:
                if jconf.arrowPySparkFallbackEnabled():
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
                        "failed by the reason below:\n%s\n"
                        "Attempting non-optimization as "
                        "'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to "
                        "true." % str(e)
                    )
                    warn(msg)
                    use_arrow = False
                else:
                    msg = (
                        "toPandas attempted Arrow optimization because "
                        "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                        "reached the error below and will not continue because automatic fallback "
                        "with 'spark.sql.execution.arrow.pyspark.fallback.enabled' has been set to "
                        "false.\n%s" % str(e)
                    )
                    warn(msg)
                    raise

        # Try to use Arrow optimization when the schema is supported and the required version
        # of PyArrow is found, if 'spark.sql.execution.arrow.pyspark.enabled' is enabled.
        if use_arrow:
            try:
                from pyspark.sql.pandas.types import (
                    _check_series_localize_timestamps,
                    _convert_map_items_to_dict,
                )
                import pyarrow

                # Rename columns to avoid duplicated column names.
                tmp_column_names = ["col_{}".format(i) for i in range(len(self.columns))]
                self_destruct = jconf.arrowPySparkSelfDestructEnabled()
                batches = self.toDF(*tmp_column_names)._collect_as_arrow(
                    split_batches=self_destruct
                )
                if len(batches) > 0:
                    table = pyarrow.Table.from_batches(batches)
                    # Ensure only the table has a reference to the batches, so that
                    # self_destruct (if enabled) is effective
                    del batches
                    # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
                    # values, but we should use datetime.date to match the behavior with when
                    # Arrow optimization is disabled.
                    pandas_options = {"date_as_object": True}
                    if self_destruct:
                        # Configure PyArrow to use as little memory as possible:
                        # self_destruct - free columns as they are converted
                        # split_blocks - create a separate Pandas block for each column
                        # use_threads - convert one column at a time
                        pandas_options.update(
                            {
                                "self_destruct": True,
                                "split_blocks": True,
                                "use_threads": False,
                            }
                        )
                    pdf = table.to_pandas(**pandas_options)
                    # Rename back to the original column names.
                    pdf.columns = self.columns
                    for field in self.schema:
                        if isinstance(field.dataType, TimestampType):
                            pdf[field.name] = _check_series_localize_timestamps(
                                pdf[field.name], timezone
                            )
                        elif isinstance(field.dataType, MapType):
                            pdf[field.name] = _convert_map_items_to_dict(pdf[field.name])
                    return pdf
                else:
                    corrected_panda_types = {}
                    for index, field in enumerate(self.schema):
                        pandas_type = PandasConversionMixin._to_corrected_pandas_type(
                            field.dataType
                        )
                        corrected_panda_types[tmp_column_names[index]] = (
                            object if pandas_type is None else pandas_type
                        )

                    pdf = pd.DataFrame(columns=tmp_column_names).astype(
                        dtype=corrected_panda_types
                    )
                    pdf.columns = self.columns
                    return pdf
            except Exception as e:
                # We might have to allow fallback here as well but multiple Spark jobs can
                # be executed. So, simply fail in this case for now.
                msg = (
                    "toPandas attempted Arrow optimization because "
                    "'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has "
                    "reached the error below and can not continue. Note that "
                    "'spark.sql.execution.arrow.pyspark.fallback.enabled' does not have an "
                    "effect on failures in the middle of "
                    "computation.\n%s" % str(e)
                )
                warn(msg)
                raise

        # Below is toPandas without Arrow optimization.
        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
        column_counter = Counter(self.columns)

        corrected_dtypes: List[Optional[Type]] = [None] * len(self.schema)
        for index, field in enumerate(self.schema):
            # We use `iloc` to access columns with duplicate column names.
            if column_counter[field.name] > 1:
                pandas_col = pdf.iloc[:, index]
            else:
                pandas_col = pdf[field.name]

            pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
            # SPARK-21766: if an integer field is nullable and has null values, it can be
            # inferred by pandas as a float column. If we convert the column with NaN back
            # to integer type e.g., np.int16, we will hit an exception. So we use the
            # pandas-inferred float type, rather than the corrected type from the schema
            # in this case.
            if pandas_type is not None and not (
                isinstance(field.dataType, IntegralType)
                and field.nullable
                and pandas_col.isnull().any()
            ):
                corrected_dtypes[index] = pandas_type
            # Ensure we fall back to nullable numpy types.
            if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
                corrected_dtypes[index] = np.float64
            if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
                corrected_dtypes[index] = object

        df = pd.DataFrame()
        for index, t in enumerate(corrected_dtypes):
            column_name = self.schema[index].name

            # We use `iloc` to access columns with duplicate column names.
            if column_counter[column_name] > 1:
                series = pdf.iloc[:, index]
            else:
                series = pdf[column_name]

            # No need to cast a non-empty timedelta series; the type is already correct.
            should_check_timedelta = is_timedelta64_dtype(t) and len(pdf) == 0
            if (t is not None and not is_timedelta64_dtype(t)) or should_check_timedelta:
                series = series.astype(t, copy=False)

            with catch_warnings():
                from pandas.errors import PerformanceWarning

                simplefilter(action="ignore", category=PerformanceWarning)
                # The `insert` API makes a copy of the data; we only use it for Series
                # with duplicate column names. `pdf.iloc[:, index] = pdf.iloc[:, index]...`
                # doesn't always work because `iloc` could return a view or a copy
                # depending on the context.
                if column_counter[column_name] > 1:
                    df.insert(index, column_name, series, allow_duplicates=True)
                else:
                    df[column_name] = series

        if timezone is None:
            return df
        else:
            from pyspark.sql.pandas.types import _check_series_convert_timestamps_local_tz

            for field in self.schema:
                # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
                if isinstance(field.dataType, TimestampType):
                    df[field.name] = _check_series_convert_timestamps_local_tz(
                        df[field.name], timezone
                    )
            return df

    @staticmethod
    def _to_corrected_pandas_type(dt: DataType) -> Optional[Type]:
        """
        When converting Spark SQL records to Pandas `pandas.DataFrame`, the inferred data type
        may be wrong. This method gets the corrected data type for Pandas if that type may be
        inferred incorrectly.
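
        A few illustrative mappings (these mirror the branches below; ``np`` is numpy)::

            _to_corrected_pandas_type(LongType())       # -> np.int64
            _to_corrected_pandas_type(DoubleType())     # -> np.float64
            _to_corrected_pandas_type(TimestampType())  # -> np.datetime64
            # any other type                            -> None, i.e. no correction is applied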
"""importnumpyasnpiftype(dt)==ByteType:returnnp.int8eliftype(dt)==ShortType:returnnp.int16eliftype(dt)==IntegerType:returnnp.int32eliftype(dt)==LongType:returnnp.int64eliftype(dt)==FloatType:returnnp.float32eliftype(dt)==DoubleType:returnnp.float64eliftype(dt)==BooleanType:returnbooleliftype(dt)==TimestampType:returnnp.datetime64eliftype(dt)==TimestampNTZType:returnnp.datetime64eliftype(dt)==DayTimeIntervalType:returnnp.timedelta64else:returnNonedef_collect_as_arrow(self,split_batches:bool=False)->List["pa.RecordBatch"]:""" Returns all records as a list of ArrowRecordBatches, pyarrow must be installed and available on driver and worker Python environments. This is an experimental feature. :param split_batches: split batches such that each column is in its own allocation, so that the selfDestruct optimization is effective; default False. .. note:: Experimental. """frompyspark.sql.dataframeimportDataFrameassertisinstance(self,DataFrame)withSCCallSiteSync(self._sc):(port,auth_secret,jsocket_auth_server,)=self._jdf.collectAsArrowToPython()# Collect list of un-ordered batches where last element is a list of correct order indicestry:batch_stream=_load_from_socket((port,auth_secret),ArrowCollectSerializer())ifsplit_batches:# When spark.sql.execution.arrow.pyspark.selfDestruct.enabled, ensure# each column in each record batch is contained in its own allocation.# Otherwise, selfDestruct does nothing; it frees each column as its# converted, but each column will actually be a list of slices of record# batches, and so no memory is actually freed until all columns are# converted.importpyarrowasparesults=[]forbatch_or_indicesinbatch_stream:ifisinstance(batch_or_indices,pa.RecordBatch):batch_or_indices=pa.RecordBatch.from_arrays([# This call actually reallocates the arraypa.concat_arrays([array])forarrayinbatch_or_indices],schema=batch_or_indices.schema,)results.append(batch_or_indices)else:results=list(batch_stream)finally:# Join serving thread and raise any exceptions from collectAsArrowToPythonjsocket_auth_server.getResult()# Separate RecordBatches from batch order indices in resultsbatches=results[:-1]batch_order=results[-1]# Re-order the batch list using the correct orderreturn[batches[i]foriinbatch_order]classSparkConversionMixin:""" Min-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession` can use this class. 
"""_jsparkSession:"JavaObject"@overloaddefcreateDataFrame(self,data:"PandasDataFrameLike",samplingRatio:Optional[float]=...)->"DataFrame":...@overloaddefcreateDataFrame(self,data:"PandasDataFrameLike",schema:Union[StructType,str],verifySchema:bool=...,)->"DataFrame":...defcreateDataFrame(# type: ignore[misc]self,data:"PandasDataFrameLike",schema:Optional[Union[StructType,List[str]]]=None,samplingRatio:Optional[float]=None,verifySchema:bool=True,)->"DataFrame":frompyspark.sqlimportSparkSessionassertisinstance(self,SparkSession)frompyspark.sql.pandas.utilsimportrequire_minimum_pandas_versionrequire_minimum_pandas_version()timezone=self._jconf.sessionLocalTimeZone()# If no schema supplied by user then get the names of columns onlyifschemaisNone:schema=[str(x)ifnotisinstance(x,str)elsexforxindata.columns]ifself._jconf.arrowPySparkEnabled()andlen(data)>0:try:returnself._create_from_pandas_with_arrow(data,schema,timezone)exceptExceptionase:ifself._jconf.arrowPySparkFallbackEnabled():msg=("createDataFrame attempted Arrow optimization because ""'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, ""failed by the reason below:\n%s\n""Attempting non-optimization as ""'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to ""true."%str(e))warn(msg)else:msg=("createDataFrame attempted Arrow optimization because ""'spark.sql.execution.arrow.pyspark.enabled' is set to true, but has ""reached the error below and will not continue because automatic ""fallback with 'spark.sql.execution.arrow.pyspark.fallback.enabled' ""has been set to false.\n%s"%str(e))warn(msg)raiseconverted_data=self._convert_from_pandas(data,schema,timezone)returnself._create_dataframe(converted_data,schema,samplingRatio,verifySchema)def_convert_from_pandas(self,pdf:"PandasDataFrameLike",schema:Union[StructType,str,List[str]],timezone:str)->List:""" Convert a pandas.DataFrame to list of records that can be used to make a DataFrame Returns ------- list list of records """importpandasaspdfrompyspark.sqlimportSparkSessionassertisinstance(self,SparkSession)iftimezoneisnotNone:frompyspark.sql.pandas.typesimport_check_series_convert_timestamps_tz_localfrompandas.core.dtypes.commonimportis_datetime64tz_dtype,is_timedelta64_dtypecopied=Falseifisinstance(schema,StructType):forfieldinschema:# TODO: handle nested timestamps, such as ArrayType(TimestampType())?ifisinstance(field.dataType,TimestampType):s=_check_series_convert_timestamps_tz_local(pdf[field.name],timezone)ifsisnotpdf[field.name]:ifnotcopied:# Copy once if the series is modified to prevent the original# Pandas DataFrame from being updatedpdf=pdf.copy()copied=Truepdf[field.name]=selse:should_localize=notis_timestamp_ntz_preferred()forcolumn,seriesinpdf.items():s=seriesifshould_localizeandis_datetime64tz_dtype(s.dtype)ands.dt.tzisnotNone:s=_check_series_convert_timestamps_tz_local(series,timezone)ifsisnotseries:ifnotcopied:# Copy once if the series is modified to prevent the original# Pandas DataFrame from being updatedpdf=pdf.copy()copied=Truepdf[column]=sforcolumn,seriesinpdf.items():ifis_timedelta64_dtype(series):ifnotcopied:pdf=pdf.copy()copied=True# Explicitly set the timedelta as object so the output of numpy records can# hold the timedelta instances as are. 
                    ser = pdf[column]
                    pdf[column] = pd.Series(
                        ser.dt.to_pytimedelta(), index=ser.index, dtype="object", name=ser.name
                    )

        # Convert pandas.DataFrame to list of numpy records
        np_records = pdf.to_records(index=False)

        # Check if any columns need to be fixed for Spark to infer properly
        if len(np_records) > 0:
            record_dtype = self._get_numpy_record_dtype(np_records[0])
            if record_dtype is not None:
                return [r.astype(record_dtype).tolist() for r in np_records]

        # Convert list of numpy records to python lists
        return [r.tolist() for r in np_records]

    def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
        """
        Used when converting a pandas.DataFrame to Spark using to_records(); this will correct
        the dtypes of fields in a record so they can be properly loaded into Spark.

        Parameters
        ----------
        rec : numpy.record
            a numpy record to check field dtypes

        Returns
        -------
        numpy.dtype
            corrected dtype for a numpy.record, or None if no correction is needed
        """
        import numpy as np

        cur_dtypes = rec.dtype
        col_names = cur_dtypes.names
        record_type_list = []
        has_rec_fix = False
        for i in range(len(cur_dtypes)):
            curr_type = cur_dtypes[i]
            # If type is a datetime64 timestamp, convert to microseconds
            # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
            # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
            if curr_type == np.dtype("datetime64[ns]"):
                curr_type = "datetime64[us]"
                has_rec_fix = True
            record_type_list.append((str(col_names[i]), curr_type))
        return np.dtype(record_type_list) if has_rec_fix else None

    def _create_from_pandas_with_arrow(
        self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
    ) -> "DataFrame":
        """
        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions,
        converting to Arrow data, then sending to the JVM to parallelize. If a schema is
        passed in, the data types will be used to coerce the data in the pandas-to-Arrow
        conversion.
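
        For example (illustrative; assumes ``spark.sql.execution.arrow.maxRecordsPerBatch``
        is set to 10,000)::

            25,000 rows -> slices of 10,000 + 10,000 + 5,000 rows -> 3 Arrow record batches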
"""frompyspark.sqlimportSparkSessionfrompyspark.sql.dataframeimportDataFrameassertisinstance(self,SparkSession)frompyspark.sql.pandas.serializersimportArrowStreamPandasSerializerfrompyspark.sql.typesimportTimestampTypefrompyspark.sql.pandas.typesimportfrom_arrow_type,to_arrow_typefrompyspark.sql.pandas.utilsimport(require_minimum_pandas_version,require_minimum_pyarrow_version,)require_minimum_pandas_version()require_minimum_pyarrow_version()frompandas.api.typesimport(# type: ignore[attr-defined]is_datetime64_dtype,is_datetime64tz_dtype,)importpyarrowaspa# Create the Spark schema from list of names passed in with Arrow typesifisinstance(schema,(list,tuple)):arrow_schema=pa.Schema.from_pandas(pdf,preserve_index=False)struct=StructType()prefer_timestamp_ntz=is_timestamp_ntz_preferred()forname,fieldinzip(schema,arrow_schema):struct.add(name,from_arrow_type(field.type,prefer_timestamp_ntz),nullable=field.nullable)schema=struct# Determine arrow types to coerce data when creating batchesifisinstance(schema,StructType):arrow_types=[to_arrow_type(f.dataType)forfinschema.fields]elifisinstance(schema,DataType):raiseValueError("Single data type %s is not supported with Arrow"%str(schema))else:# Any timestamps must be coerced to be compatible with Sparkarrow_types=[to_arrow_type(TimestampType())ifis_datetime64_dtype(t)oris_datetime64tz_dtype(t)elseNonefortinpdf.dtypes]# Slice the DataFrame to be batchedstep=self._jconf.arrowMaxRecordsPerBatch()step=stepifstep>0elselen(pdf)pdf_slices=(pdf.iloc[start:start+step]forstartinrange(0,len(pdf),step))# Create list of Arrow (columns, type) for serializer dump_streamarrow_data=[[(c,t)for(_,c),tinzip(pdf_slice.items(),arrow_types)]forpdf_sliceinpdf_slices]jsparkSession=self._jsparkSessionsafecheck=self._jconf.arrowSafeTypeConversion()col_by_name=True# col by name only applies to StructType columns, can't happen hereser=ArrowStreamPandasSerializer(timezone,safecheck,col_by_name)@no_type_checkdefreader_func(temp_filename):returnself._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)@no_type_checkdefcreate_iter_server():returnself._jvm.ArrowIteratorServer()# Create Spark DataFrame from Arrow stream file, using one batch per partitionjiter=self._sc._serialize_to_jvm(arrow_data,ser,reader_func,create_iter_server)assertself._jvmisnotNonejdf=self._jvm.PythonSQLUtils.toDataFrame(jiter,schema.json(),jsparkSession)df=DataFrame(jdf,self)df._schema=schemareturndfdef_test()->None:importdoctestfrompyspark.sqlimportSparkSessionimportpyspark.sql.pandas.conversionglobs=pyspark.sql.pandas.conversion.__dict__.copy()spark=(SparkSession.builder.master("local[4]").appName("sql.pandas.conversion tests").getOrCreate())globs["spark"]=spark(failure_count,test_count)=doctest.testmod(pyspark.sql.pandas.conversion,globs=globs,optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE|doctest.REPORT_NDIFF,)spark.stop()iffailure_count:sys.exit(-1)if__name__=="__main__":_test()